Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions mlx_vlm/models/gemma4/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,12 +503,18 @@ def __call__(self, pixel_values) -> mx.array:
patch_positions = mx.concatenate(all_positions, axis=0)
padding_positions = mx.concatenate(all_padding, axis=0)

# Build bidirectional attention mask [B, 1, L, L] for SDPA
# Build bidirectional attention mask [B, 1, L, L] for SDPA.
# Use a large finite negative mask value rather than -inf. When the
# input contains padded patches, queries at those positions see all
# keys masked; an all -inf softmax row produces NaN during backward
# (exp(-inf - -inf) = NaN), even though the forward value is finite
# because the padded outputs are zeroed downstream. A finite mask
# keeps the gradient defined, which is required for LoRA / fine-tuning.
valid_mask = ~padding_positions
attn_mask = mx.expand_dims(valid_mask, 1) * mx.expand_dims(valid_mask, 2)
neg_inf = mx.array(float("-inf"), dtype=inputs_embeds.dtype)
mask_fill = mx.array(-1e4, dtype=inputs_embeds.dtype)
attn_mask = mx.where(
attn_mask, mx.array(0.0, dtype=inputs_embeds.dtype), neg_inf
attn_mask, mx.array(0.0, dtype=inputs_embeds.dtype), mask_fill
)
attn_mask = mx.expand_dims(attn_mask, 1)

Expand Down
41 changes: 31 additions & 10 deletions mlx_vlm/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,39 @@ def get_peft_model(


def freeze_model(model):
    """Freeze every non-trainable tower of a multimodal model in place.

    Walks ``model.named_modules()`` and freezes each *top-level* tower whose
    name appears in ``top_level_to_freeze``, so that only LoRA adapter
    parameters remain trainable.

    Gemma 4 (and other multimodal models with an audio tower) also expose
    ``audio_tower`` / ``embed_audio`` / ``embed_vision`` at the top level;
    omitting them leaves hundreds of MB of unrelated weights marked
    trainable, which bloats adapter files and invites gradient leakage.

    Args:
        model: an ``mlx.nn.Module``-like model that supports ``named_modules()``
            and ``model[name]`` item access to its top-level towers.

    Returns:
        None. The model is mutated in place.
    """
    top_level_to_freeze = {
        "language_model",
        "vision_model",
        "vision_tower",
        "aligner",
        "connector",
        "multi_modal_projector",
        "mm_projector",
        "audio_tower",
        "embed_audio",
        "embed_vision",
    }
    # named_modules() yields every submodule; freeze each tower only once
    # instead of re-freezing it for each of its (possibly hundreds of)
    # descendants.
    frozen = set()
    for qualified_name, _ in model.named_modules():
        top = qualified_name.split(".")[0]
        # hasattr guard: not every architecture exposes every tower name.
        if top in frozen or top not in top_level_to_freeze or not hasattr(model, top):
            continue
        frozen.add(top)
        tower = model[f"{top}"]
        try:
            tower.freeze()
        except Exception:
            # Some multimodal towers have custom sub-modules (e.g.
            # AudioRelativePositionEmbedding) whose ``freeze`` errors out
            # on non-Module buffers. Fall back to walking leaf modules.
            try:
                from mlx.utils import tree_flatten

                leaves = tree_flatten(
                    tower.leaf_modules(),
                    is_leaf=lambda m: isinstance(m, nn.Module),
                )
                for _, leaf in leaves:
                    leaf.freeze(recurse=False)
            except Exception:
                # Best-effort: an unfreezable tower is left as-is rather
                # than aborting the whole fine-tuning setup.
                pass


def find_all_linear_names(model):
Expand Down