@@ -223,7 +223,7 @@ private class Gemma4PatchEmbedder: Module {
223223 outputChannels: config. hiddenSize,
224224 kernelSize: [ config. patchSize, config. patchSize] ,
225225 stride: [ config. patchSize, config. patchSize] ,
226- bias: true
226+ bias: false
227227 )
228228 // Set the parameter directly. MLX requires Module parameters to be either Module or explicitly managed MLXArrays.
229229 self . position_embedding_table = zeros ( [ 2 , 10240 , config. hiddenSize] )
@@ -278,7 +278,7 @@ private class Gemma4Projector: Module, UnaryLayer {
278278 @ModuleInfo ( key: " embedding_projection " ) var projection : any UnaryLayer
279279
280280 init ( visionDim: Int , textDim: Int ) {
281- self . _projection. wrappedValue = Linear ( visionDim, textDim)
281+ self . _projection. wrappedValue = Linear ( visionDim, textDim, bias : false )
282282 super. init ( )
283283 }
284284
@@ -478,6 +478,19 @@ public class Gemma4VL: Module, VLMModel, KVCacheDimensionProvider, LayerPartitio
478478 continue
479479 }
480480
481+ // Fix MLX Conv2d spatial scrambling: Hugging Face SigLIP patch extractors are often linear layers
482+ // that export a flat 2-D weight of shape [O, C*H*W] (e.g. [768, 768]), whereas MLX Conv2d expects [O, H, W, C].
483+ // Reshape the flat weight to the PyTorch layout [O, C, H, W], then transpose it to MLX's [O, H, W, C].
484+ if newK == " vision_tower.patch_embedder.input_proj.weight " && v. ndim == 2 {
485+ let outChannels = v. dim ( 0 )
486+ let spatialChannels = v. dim ( 1 )
487+ if spatialChannels == visionConfig. numChannels * visionConfig. patchSize * visionConfig. patchSize {
488+ let reshaped = v. reshaped ( [ outChannels, visionConfig. numChannels, visionConfig. patchSize, visionConfig. patchSize] )
489+ processed [ newK] = reshaped. transposed ( 0 , 2 , 3 , 1 )
490+ continue
491+ }
492+ }
493+
481494 processed [ newK] = v
482495 }
483496 }
@@ -600,7 +613,10 @@ public struct Gemma4Processor: UserInputProcessor {
600613
601614 // If the chat template completely dropped the image tokens, inject them manually!
602615 if expandedTokens. count == promptTokens. count && !promptTokens. contains ( imageTokenId) {
603- let imagePad = Array ( repeating: imageTokenId, count: numTokens)
616+ var imagePad = [ 255999 ] // <|image>
617+ imagePad. append ( contentsOf: Array ( repeating: imageTokenId, count: numTokens) )
618+ imagePad. append ( 258882 ) // <image|>
619+
604620 if expandedTokens. first == 2 {
605621 // Inject right after BOS (2)
606622 expandedTokens. insert ( contentsOf: imagePad, at: 1 )
@@ -638,10 +654,13 @@ public struct Gemma4Processor: UserInputProcessor {
638654 let expectedAudioTokens = layer1Length
639655
640656 var expandedTokens = promptTokens
641- let audioPadding = Array ( repeating: audioTokenId, count: expectedAudioTokens)
642657 let gemmaBoa = 256000 // <|audio>
643658 let gemmaEoa = 258883 // <audio|>
644659
660+ var audioPadding = [ gemmaBoa]
661+ audioPadding. append ( contentsOf: Array ( repeating: audioTokenId, count: expectedAudioTokens) )
662+ audioPadding. append ( gemmaEoa)
663+
645664 // The MessageGenerator injected <|audio|> strings which tokenizer resolves to audioTokenId (258881)
646665 // Determine insertion point before wiping ALL instances.
647666 let targetIdx = expandedTokens. firstIndex ( of: gemmaBoa) ?? expandedTokens. firstIndex ( of: audioTokenId)
0 commit comments