Skip to content

Commit 112c45e

Browse files
author
Aegis AI Assistant
committed
fix(gemma4): restore multimodal scale magnitudes and stabilize token wrappers
1 parent bc9c956 commit 112c45e

1 file changed

Lines changed: 23 additions & 4 deletions

File tree

Libraries/MLXVLM/Models/Gemma4VL.swift

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,7 @@ private class Gemma4PatchEmbedder: Module {
223223
outputChannels: config.hiddenSize,
224224
kernelSize: [config.patchSize, config.patchSize],
225225
stride: [config.patchSize, config.patchSize],
226-
bias: true
226+
bias: false
227227
)
228228
// Set the parameter directly. MLX requires Module parameters to be either Module or explicitly managed MLXArrays.
229229
self.position_embedding_table = zeros([2, 10240, config.hiddenSize])
@@ -278,7 +278,7 @@ private class Gemma4Projector: Module, UnaryLayer {
278278
@ModuleInfo(key: "embedding_projection") var projection: any UnaryLayer
279279

280280
init(visionDim: Int, textDim: Int) {
281-
self._projection.wrappedValue = Linear(visionDim, textDim)
281+
self._projection.wrappedValue = Linear(visionDim, textDim, bias: false)
282282
super.init()
283283
}
284284

@@ -478,6 +478,19 @@ public class Gemma4VL: Module, VLMModel, KVCacheDimensionProvider, LayerPartitio
478478
continue
479479
}
480480

481+
// Fix MLX Conv2d spatial scrambling: Hugging Face SigLIP patch extractors are often linear layers
482+
// which export flat [O, C*H*W] weights (e.g. [768, 768]). MLX Conv2d expects [O, H, W, C].
483+
// Reshape to the PyTorch-native [O, C, H, W] layout, then transpose to the MLX Conv2d layout [O, H, W, C].
484+
if newK == "vision_tower.patch_embedder.input_proj.weight" && v.ndim == 2 {
485+
let outChannels = v.dim(0)
486+
let spatialChannels = v.dim(1)
487+
if spatialChannels == visionConfig.numChannels * visionConfig.patchSize * visionConfig.patchSize {
488+
let reshaped = v.reshaped([outChannels, visionConfig.numChannels, visionConfig.patchSize, visionConfig.patchSize])
489+
processed[newK] = reshaped.transposed(0, 2, 3, 1)
490+
continue
491+
}
492+
}
493+
481494
processed[newK] = v
482495
}
483496
}
@@ -600,7 +613,10 @@ public struct Gemma4Processor: UserInputProcessor {
600613

601614
// If the chat template completely dropped the image tokens, inject them manually!
602615
if expandedTokens.count == promptTokens.count && !promptTokens.contains(imageTokenId) {
603-
let imagePad = Array(repeating: imageTokenId, count: numTokens)
616+
var imagePad = [255999] // <|image>
617+
imagePad.append(contentsOf: Array(repeating: imageTokenId, count: numTokens))
618+
imagePad.append(258882) // <image|>
619+
604620
if expandedTokens.first == 2 {
605621
// Inject right after BOS (2)
606622
expandedTokens.insert(contentsOf: imagePad, at: 1)
@@ -638,10 +654,13 @@ public struct Gemma4Processor: UserInputProcessor {
638654
let expectedAudioTokens = layer1Length
639655

640656
var expandedTokens = promptTokens
641-
let audioPadding = Array(repeating: audioTokenId, count: expectedAudioTokens)
642657
let gemmaBoa = 256000 // <|audio>
643658
let gemmaEoa = 258883 // <audio|>
644659

660+
var audioPadding = [gemmaBoa]
661+
audioPadding.append(contentsOf: Array(repeating: audioTokenId, count: expectedAudioTokens))
662+
audioPadding.append(gemmaEoa)
663+
645664
// The MessageGenerator injects <|audio|> strings, which the tokenizer resolves to audioTokenId (258881).
646665
// Determine insertion point before wiping ALL instances.
647666
let targetIdx = expandedTokens.firstIndex(of: gemmaBoa) ?? expandedTokens.firstIndex(of: audioTokenId)

0 commit comments

Comments
 (0)