@@ -1378,9 +1378,15 @@ private final class Gemma4VisionPooler: Module {
13781378 validCount: Int ,
13791379 outputLength: Int ? = nil
13801380 ) -> MLXArray {
1381+ let paddingPositions = patchPositions [ 0 ... , 0 ... , 0 ] .< 0
1382+ let pooledHiddenStates = MLX . where (
1383+ expandedDimensions ( paddingPositions, axis: - 1 ) ,
1384+ MLXArray ( 0.0 , dtype: hiddenStates. dtype) ,
1385+ hiddenStates
1386+ )
13811387 let length = outputLength ?? defaultOutputLength
1382- if hiddenStates . dim ( 1 ) <= length {
1383- return hiddenStates * MLXArray( rootHiddenSize, dtype: hiddenStates . dtype)
1388+ if pooledHiddenStates . dim ( 1 ) <= length {
1389+ return pooledHiddenStates * MLXArray( rootHiddenSize, dtype: pooledHiddenStates . dtype)
13841390 }
13851391
13861392 let actualPositions = patchPositions [ 0 , ..< validCount]
@@ -1397,9 +1403,10 @@ private final class Gemma4VisionPooler: Module {
13971403 let weights =
13981404 gemma4OneHot ( flatKernel, numClasses: pooledLength) . asType ( . float32)
13991405 / Float( divisor)
1400- let output = einsum ( " lL,bld->bLd " , weights, hiddenStates [ 0 ... , ..< validCount, 0 ... ] )
1401- . asType ( hiddenStates. dtype)
1402- return output * MLXArray( rootHiddenSize, dtype: hiddenStates. dtype)
1406+ let output = einsum (
1407+ " lL,bld->bLd " , weights, pooledHiddenStates [ 0 ... , ..< validCount, 0 ... ] )
1408+ . asType ( pooledHiddenStates. dtype)
1409+ return output * MLXArray( rootHiddenSize, dtype: pooledHiddenStates. dtype)
14031410 }
14041411}
14051412
@@ -1523,17 +1530,17 @@ private final class Gemma4VisionModel: Module {
15231530
15241531private final class Gemma4MultimodalEmbedder : Module , UnaryLayer {
15251532 @ModuleInfo ( key: " embedding_projection " ) var embeddingProjection : Linear
1526- @ModuleInfo ( key: " embedding_post_projection_norm " ) var embeddingPostProjectionNorm :
1533+ @ModuleInfo ( key: " embedding_pre_projection_norm " ) var embeddingPreProjectionNorm :
15271534 Gemma4RMSNormNoScale
15281535
15291536 init ( embeddingDim: Int , textHiddenSize: Int , eps: Float ) {
15301537 self . _embeddingProjection. wrappedValue = Linear ( embeddingDim, textHiddenSize, bias: false )
1531- self . _embeddingPostProjectionNorm . wrappedValue = Gemma4RMSNormNoScale ( eps: eps)
1538+ self . _embeddingPreProjectionNorm . wrappedValue = Gemma4RMSNormNoScale ( eps: eps)
15321539 super. init ( )
15331540 }
15341541
15351542 func callAsFunction( _ x: MLXArray ) -> MLXArray {
1536- embeddingPostProjectionNorm ( embeddingProjection ( x) )
1543+ embeddingProjection ( embeddingPreProjectionNorm ( x) )
15371544 }
15381545}
15391546
0 commit comments