Skip to content

Commit 67b5729

Browse files
committed
Fixes for gemma 4 vision
1 parent 3eb7527 commit 67b5729

File tree

1 file changed: 15 additions (+15) and 8 deletions (-8) lines changed

1 file changed

+15
-8
lines changed

Libraries/MLXVLM/Models/Gemma4.swift

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1378,9 +1378,15 @@ private final class Gemma4VisionPooler: Module {
13781378
validCount: Int,
13791379
outputLength: Int? = nil
13801380
) -> MLXArray {
1381+
let paddingPositions = patchPositions[0..., 0..., 0] .< 0
1382+
let pooledHiddenStates = MLX.where(
1383+
expandedDimensions(paddingPositions, axis: -1),
1384+
MLXArray(0.0, dtype: hiddenStates.dtype),
1385+
hiddenStates
1386+
)
13811387
let length = outputLength ?? defaultOutputLength
1382-
if hiddenStates.dim(1) <= length {
1383-
return hiddenStates * MLXArray(rootHiddenSize, dtype: hiddenStates.dtype)
1388+
if pooledHiddenStates.dim(1) <= length {
1389+
return pooledHiddenStates * MLXArray(rootHiddenSize, dtype: pooledHiddenStates.dtype)
13841390
}
13851391

13861392
let actualPositions = patchPositions[0, ..<validCount]
@@ -1397,9 +1403,10 @@ private final class Gemma4VisionPooler: Module {
13971403
let weights =
13981404
gemma4OneHot(flatKernel, numClasses: pooledLength).asType(.float32)
13991405
/ Float(divisor)
1400-
let output = einsum("lL,bld->bLd", weights, hiddenStates[0..., ..<validCount, 0...])
1401-
.asType(hiddenStates.dtype)
1402-
return output * MLXArray(rootHiddenSize, dtype: hiddenStates.dtype)
1406+
let output = einsum(
1407+
"lL,bld->bLd", weights, pooledHiddenStates[0..., ..<validCount, 0...])
1408+
.asType(pooledHiddenStates.dtype)
1409+
return output * MLXArray(rootHiddenSize, dtype: pooledHiddenStates.dtype)
14031410
}
14041411
}
14051412

@@ -1523,17 +1530,17 @@ private final class Gemma4VisionModel: Module {
15231530

15241531
private final class Gemma4MultimodalEmbedder: Module, UnaryLayer {
15251532
@ModuleInfo(key: "embedding_projection") var embeddingProjection: Linear
1526-
@ModuleInfo(key: "embedding_post_projection_norm") var embeddingPostProjectionNorm:
1533+
@ModuleInfo(key: "embedding_pre_projection_norm") var embeddingPreProjectionNorm:
15271534
Gemma4RMSNormNoScale
15281535

15291536
init(embeddingDim: Int, textHiddenSize: Int, eps: Float) {
15301537
self._embeddingProjection.wrappedValue = Linear(embeddingDim, textHiddenSize, bias: false)
1531-
self._embeddingPostProjectionNorm.wrappedValue = Gemma4RMSNormNoScale(eps: eps)
1538+
self._embeddingPreProjectionNorm.wrappedValue = Gemma4RMSNormNoScale(eps: eps)
15321539
super.init()
15331540
}
15341541

15351542
func callAsFunction(_ x: MLXArray) -> MLXArray {
1536-
embeddingPostProjectionNorm(embeddingProjection(x))
1543+
embeddingProjection(embeddingPreProjectionNorm(x))
15371544
}
15381545
}
15391546

0 commit comments

Comments
 (0)