
Fix Gemma 4 LoRA training: vision backward NaN + audio_tower freeze leak #1052

Merged
Blaizzy merged 4 commits into Blaizzy:main from john-rocky:fix/gemma4-lora-training
Apr 24, 2026

Conversation

@john-rocky
Contributor

Problem

Fine-tuning the multimodal google/gemma-4-E2B-it checkpoint with the current SFT + LoRA pipeline hits two Gemma-4-specific bugs that together make training impossible or highly wasteful.


1. vision_tower backward produces NaN on padded images (blocker)

When pixel_values contain fewer real patches than max_patches (any non-square input or smaller-than-max resolution), the padded query rows in the bidirectional attention mask are entirely -inf. Softmax over an all-(-inf) row is finite in the forward pass (those positions are zeroed out later in the pooler), but its backward produces NaN via exp(-inf - (-inf)) = exp(NaN), so every trainable gradient (LoRA or full fine-tune) ends up NaN.

Minimal repro on vanilla google/gemma-4-E2B-it:

```python
import mlx.core as mx, mlx.nn as nn
from mlx_vlm.utils import load
from mlx_vlm.trainer.utils import freeze_model

model, _ = load("google/gemma-4-E2B-it", processor_config={"trust_remote_code": True})
freeze_model(model)
model.vision_tower.unfreeze()

# any pixel_values where num_real < max_patches, e.g. a random 960x624 tensor
pv = mx.random.normal(shape=(1, 3, 960, 624))

def f(x):
    return model.vision_tower(x).sum()

v_direct = f(pv)                                         # finite
v_vg, _ = nn.value_and_grad(model.vision_tower, f)(pv)   # NaN
```

The forward alone is fine, so the symptom only appears once a training loop wraps the forward in nn.value_and_grad — which makes it easy to misdiagnose as a LoRA or dtype bug.

Fix: fill masked positions with a large finite negative (-1e4, in the image-embed dtype) instead of -inf. Forward behaviour is unchanged (exp(-1e4) underflows to zero, and the padded outputs are still zeroed downstream by the pooler), but the softmax gradient stays finite, so LoRA / full FT on Gemma 4 converges normally.
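
Roughly, the change amounts to building the additive attention mask with a finite fill value. The sketch below is illustrative only: the helper name and the `valid`-mask layout are made up, and the actual construction in mlx_vlm/models/gemma4/vision.py may differ.

```python
import mlx.core as mx

def make_bidirectional_mask(valid, dtype=mx.bfloat16):
    # valid: (B, L) boolean, True for real patches, False for padding.
    # Attention is allowed only between two real patches.
    allowed = valid[:, None, :] & valid[:, :, None]              # (B, L, L)
    # Large finite negative instead of -inf: exp(-1e4) still underflows to 0,
    # so the forward output is unchanged, but fully-padded query rows no
    # longer turn the softmax backward into NaN.
    mask = mx.where(allowed, mx.array(0.0, dtype=dtype), mx.array(-1e4, dtype=dtype))
    return mask[:, None, :, :]                                   # (B, 1, L, L) for MHA
```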


2. freeze_model leaks audio_tower / embed_audio / embed_vision

freeze_model hard-codes a top-level freeze list that pre-dates Gemma 4's audio branch. So after get_peft_model hands the model to the SFT trainer, every audio_tower.* weight (plus embed_audio, embed_vision) is still trainable. Two user-visible consequences:

  • save_adapter dumps model.trainable_parameters(), so adapter files balloon from ~100 MB to ~700 MB per checkpoint with unrelated audio weights.
  • Those weights silently receive gradients (even on image-only batches where no audio flows through), which is surprising and wastes optimizer memory.

Fix: add audio_tower / embed_audio / embed_vision to the top-level freeze list. A fallback is included because some audio sub-modules (notably AudioRelativePositionEmbedding, whose _inv_timescales is a plain buffer) trip nn.Module.freeze(); on exception we walk leaf modules instead.
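
As a sketch of the shape of that change (module names come from this PR; the helper name is hypothetical and the real freeze_model in mlx_vlm/trainer/utils.py also covers the pre-existing entries):

```python
import mlx.nn as nn

GEMMA4_EXTRA_FROZEN = ("audio_tower", "embed_audio", "embed_vision")

def freeze_extra_towers(model: nn.Module) -> None:
    for name in GEMMA4_EXTRA_FROZEN:
        tower = getattr(model, name, None)
        if tower is None:
            continue  # other model families simply lack these branches
        try:
            tower.freeze()
        except Exception:
            # Some audio sub-modules hold plain buffers (e.g. _inv_timescales on
            # AudioRelativePositionEmbedding) that make a recursive freeze() raise;
            # fall back to freezing each sub-module individually and skip any
            # that still refuse.
            for child in tower.modules():
                try:
                    child.freeze(recurse=False)
                except Exception:
                    pass
```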


Verified

With both fixes applied, Gemma 4 E2B LoRA fine-tuning on Unsplash / Pexels photos runs to completion on an M4 Max (500 iters, batch 1, bf16):

  • loss 10.6 → 1.1 (no NaN)
  • adapter file size: ~110 MB per checkpoint (was ~700 MB)
  • LoRA merge + HF-format round-trip works for downstream CoreML conversion
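
A quick way to sanity-check the adapter-size claim before saving (a sketch; key prefixes depend on the exact model structure, and `model` is assumed to be the checkpoint after freeze_model + get_peft_model):

```python
from mlx.utils import tree_flatten

# Only LoRA weights should remain trainable at this point.
trainable = tree_flatten(model.trainable_parameters())
assert not any(k.startswith(("audio_tower", "embed_audio", "embed_vision"))
               for k, _ in trainable)
n_params = sum(v.size for _, v in trainable)
print(f"{len(trainable)} trainable tensors, ~{n_params / 1e6:.1f}M parameters")
```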

Scope

Two surgical changes; no API changes:

  • mlx_vlm/models/gemma4/vision.py — mask fill value
  • mlx_vlm/trainer/utils.py — freeze-list additions + leaf-walk fallback

Context: I maintain a separate repo where this sequence (LoRA-train Gemma 4 → fuse → CoreML bundle → iPhone) is productized; happy to split into two PRs if you prefer.

@Blaizzy
Owner

This looks good, could you trim the comments? Then we can merge

@john-rocky
Contributor Author

Trimmed the comments in b01c289 — kept only the non-obvious WHY (the NaN-in-backward rationale on the mask, and the freeze() fallback). PTAL.

@Blaizzy
Owner

LGTM, thanks!

@Blaizzy Blaizzy merged commit 0f903f9 into Blaizzy:main Apr 24, 2026
1 check passed