Skip to content

Commit 1ff327c

Browse files
author
Aegis-AI
committed
fix(sync): resolve overlapping omni-model tied word embedding logic
1 parent 640b63f commit 1ff327c

2 files changed

Lines changed: 17 additions & 7 deletions

File tree

Libraries/MLXVLM/Models/Gemma4VL.swift

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,11 @@ public class Gemma4VL: Module, VLMModel, KVCacheDimensionProvider, LayerPartitio
517517
// This MUST be done because we explicitly allocated a separate lm_head linear layer!
518518
if processed["lm_head.weight"] == nil || config.tieWordEmbeddings {
519519
// Check both prefixed and flat keys to be robust against different sanitization outputs
520-
let embedKeys = ["model.embed_tokens.weight", "embed_tokens.weight", "model.embedTokens.weight", "embedTokens.weight"]
520+
let prefix = "language_model."
521+
let embedKeys = [
522+
"\(prefix)model.embed_tokens.weight", "\(prefix)embed_tokens.weight", "\(prefix)model.embedTokens.weight", "\(prefix)embedTokens.weight",
523+
"model.embed_tokens.weight", "embed_tokens.weight", "model.embedTokens.weight", "embedTokens.weight"
524+
]
521525

522526
for key in embedKeys {
523527
if let embedWeights = processed[key] {
@@ -528,15 +532,21 @@ public class Gemma4VL: Module, VLMModel, KVCacheDimensionProvider, LayerPartitio
528532
}
529533

530534
// Repeat for scales/biases if present (quantized models)
531-
let scaleKeys = ["model.embed_tokens.scales", "embed_tokens.scales", "model.embedTokens.scales", "embedTokens.scales"]
535+
let scaleKeys = [
536+
"\(prefix)model.embed_tokens.scales", "\(prefix)embed_tokens.scales", "\(prefix)model.embedTokens.scales", "\(prefix)embedTokens.scales",
537+
"model.embed_tokens.scales", "embed_tokens.scales", "model.embedTokens.scales", "embedTokens.scales"
538+
]
532539
for key in scaleKeys {
533540
if let embedScales = processed[key] {
534541
processed["lm_head.scales"] = embedScales
535542
break
536543
}
537544
}
538545

539-
let biasKeys = ["model.embed_tokens.biases", "embed_tokens.biases", "model.embedTokens.biases", "embedTokens.biases"]
546+
let biasKeys = [
547+
"\(prefix)model.embed_tokens.biases", "\(prefix)embed_tokens.biases", "\(prefix)model.embedTokens.biases", "\(prefix)embedTokens.biases",
548+
"model.embed_tokens.biases", "embed_tokens.biases", "model.embedTokens.biases", "embedTokens.biases"
549+
]
540550
for key in biasKeys {
541551
if let embedBiases = processed[key] {
542552
processed["lm_head.biases"] = embedBiases

Tests/MLXLMTests/Gemma4Tests.swift

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ struct Gemma4Tests {
4141
@Test("Gemma 4 Forward Pass - Determinism & Shape")
4242
func testGemma4ForwardPass() throws {
4343
let config = makeTinyConfig()
44-
let model = Gemma4Model(config)
44+
let model = Gemma4ModelInternal(config)
4545

4646
let input = MLXArray(0..<8).reshaped(1, 8)
4747
let output = model(input)
@@ -56,7 +56,7 @@ struct Gemma4Tests {
5656
@Test("PLE Multimodal Signal Integrity")
5757
func testPLESignalIntegrity() throws {
5858
let config = makeTinyConfig()
59-
let model = Gemma4Model(config)
59+
let model = Gemma4ModelInternal(config)
6060

6161
let input = MLXArray(Int32(0)..<Int32(5)).reshaped(1, 5)
6262

@@ -70,7 +70,7 @@ struct Gemma4Tests {
7070
@Test("Weight Sanitization - PLE Mapping")
7171
func testGemma4Sanitization() throws {
7272
let config = makeTinyConfig()
73-
let model = Gemma4Model(config)
73+
let model = Gemma4ModelInternal(config)
7474

7575
var weights = [String: MLXArray]()
7676
weights["model.layers.0.per_layer_conditioning.scale"] = MLXArray.ones([config.hiddenSize, config.hiddenSizePerLayerInput])
@@ -85,7 +85,7 @@ struct Gemma4Tests {
8585
@Test("Audio Configuration Dependency Safety")
8686
func testAudioConfigSafety() throws {
8787
let config = makeTinyConfig()
88-
let model = Gemma4Model(config)
88+
let model = Gemma4ModelInternal(config)
8989
#expect(model.model.layers.count == config.hiddenLayers)
9090
}
9191
}

0 commit comments

Comments
 (0)