Skip to content

Commit 7fb05f7

Browse files
authored
Merge pull request #21 from SharpAI/fix/gemma4-moe-unhandled-keys
Fix: Gemma 4 MoE Loading Failure (gemma4_text)
2 parents cf3cf2c + 35b3154 commit 7fb05f7

2 files changed

Lines changed: 217 additions & 2 deletions

File tree

Libraries/MLXLLM/Models/Gemma4Text.swift

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
3333
var attentionKeqV: Bool = false
3434
var finalLogitSoftcapping: Float = 30.0
3535
var useDoubleWideMlp: Bool = true
36+
var enableMoEBlock: Bool = false
37+
var numExperts: Int?
38+
var topKExperts: Int?
39+
var moeIntermediateSize: Int?
3640
var layerTypes: [String] = []
3741
var tieWordEmbeddings: Bool = true
3842

@@ -66,6 +70,10 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
6670
case attentionKeqV = "attention_k_eq_v"
6771
case finalLogitSoftcapping = "final_logit_softcapping"
6872
case useDoubleWideMlp = "use_double_wide_mlp"
73+
case enableMoEBlock = "enable_moe_block"
74+
case numExperts = "num_experts"
75+
case topKExperts = "top_k_experts"
76+
case moeIntermediateSize = "moe_intermediate_size"
6977
case layerTypes = "layer_types"
7078
case tieWordEmbeddings = "tie_word_embeddings"
7179
case ropeParameters = "rope_parameters"
@@ -110,6 +118,14 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
110118
try container.decodeIfPresent(Float.self, forKey: .finalLogitSoftcapping) ?? 30.0
111119
self.useDoubleWideMlp =
112120
try container.decodeIfPresent(Bool.self, forKey: .useDoubleWideMlp) ?? true
121+
self.enableMoEBlock =
122+
try container.decodeIfPresent(Bool.self, forKey: .enableMoEBlock) ?? false
123+
self.numExperts =
124+
try container.decodeIfPresent(Int.self, forKey: .numExperts)
125+
self.topKExperts =
126+
try container.decodeIfPresent(Int.self, forKey: .topKExperts)
127+
self.moeIntermediateSize =
128+
try container.decodeIfPresent(Int.self, forKey: .moeIntermediateSize)
113129
if let decoded = try container.decodeIfPresent([String].self, forKey: .layerTypes) {
114130
self.layerTypes = decoded
115131
} else {
@@ -374,6 +390,89 @@ private class Gemma4MLP: Module {
374390
}
375391
}
376392

393+
// MARK: - MoE Router
394+
395+
/// Router for the Gemma 4 sparse (mixture-of-experts) feed-forward path.
///
/// The hidden states are normalized, multiplied by `hiddenSize^-0.5` and by the
/// `scale` parameter, then projected to one score per expert. The call returns
/// the indices of the `topKExperts` highest-scoring experts together with the
/// mixing weight for each of them.
private class Gemma4TextRouter: Module {
    /// Number of experts selected for each position.
    let topKExperts: Int
    /// `hiddenSize` raised to the power `-0.5`; applied to the normalized input.
    let rootSize: Float

    @ModuleInfo(key: "norm") var norm: RMSNormNoScale
    @ModuleInfo(key: "proj") var proj: Linear
    @ModuleInfo(key: "scale") var scale: MLXArray
    @ModuleInfo(key: "per_expert_scale") var perExpertScale: MLXArray

    /// Builds the router from `config`.
    ///
    /// Traps with `fatalError` when `numExperts` or `topKExperts` is missing
    /// from the configuration.
    init(_ config: Gemma4TextConfiguration) {
        guard let expertCount = config.numExperts, let selectionCount = config.topKExperts else {
            fatalError("Gemma4 MoE router requires numExperts and topKExperts")
        }
        let width = config.hiddenSize

        topKExperts = selectionCount
        rootSize = pow(Float(width), -0.5)

        _norm.wrappedValue = RMSNormNoScale(eps: config.rmsNormEps)
        _proj.wrappedValue = Linear(width, expertCount, bias: false)
        _scale.wrappedValue = MLXArray.ones([width])
        _perExpertScale.wrappedValue = MLXArray.ones([expertCount])
        super.init()
    }

    /// Selects experts for every position of `x`.
    ///
    /// - Parameter x: Hidden states; the last axis is fed to the expert projection.
    /// - Returns: `(indices, weights)`, where `indices` holds the selected expert
    ///   ids along the last axis and `weights` the matching mixing weights.
    func callAsFunction(_ x: MLXArray) -> (MLXArray, MLXArray) {
        let normalized = norm(x)
        let rescaled = normalized * MLXArray(rootSize, dtype: normalized.dtype)
        let routerInput = rescaled * scale.asType(rescaled.dtype)

        let logits = proj(routerInput)
        // Probabilities are taken over *all* experts before the top-k selection.
        let probabilities = MLX.softmax(logits, axis: -1, precise: true)

        // Negating the scores makes the partition place the largest scores first;
        // the first `topKExperts` entries are the selected experts (unordered).
        let partitioned = MLX.argPartition(-logits, kth: topKExperts - 1, axis: -1)
        let selected = partitioned[.ellipsis, ..<topKExperts]

        // Renormalize over the selected experts only, then apply the per-expert scale.
        let gathered = MLX.takeAlong(probabilities, selected, axis: -1)
        let renormalized = gathered / MLX.sum(gathered, axis: -1, keepDims: true)
        let weights = renormalized * perExpertScale[selected].asType(renormalized.dtype)
        return (selected, weights)
    }
}
436+
437+
// MARK: - MoE Experts
438+
439+
/// Sparse expert bank for the Gemma 4 MoE feed-forward path.
///
/// Wraps a `SwitchGLU` and mixes the outputs of the routed experts using the
/// weights supplied by the caller.
private class Gemma4TextExperts: Module {
    @ModuleInfo(key: "switch_glu") var switchGLU: SwitchGLU

    /// Builds the expert bank from `config`.
    ///
    /// Traps with `fatalError` when `numExperts` or `moeIntermediateSize` is
    /// missing from the configuration.
    init(_ config: Gemma4TextConfiguration) {
        guard
            let expertCount = config.numExperts,
            let expertWidth = config.moeIntermediateSize
        else {
            fatalError("Gemma4 MoE experts require numExperts and moeIntermediateSize")
        }

        _switchGLU.wrappedValue = SwitchGLU(
            inputDims: config.hiddenSize,
            hiddenDims: expertWidth,
            numExperts: expertCount,
            activation: geluApproximate,
            bias: false
        )
        super.init()
    }

    /// Runs the routed experts and returns their weighted sum.
    ///
    /// - Parameters:
    ///   - x: Input read as three axes (`dim(0)`, `dim(1)`, `dim(2)`).
    ///   - topKIndices: Selected expert ids; the last axis is the per-position selection.
    ///   - topKWeights: Mixing weights matching `topKIndices`.
    /// - Returns: An array reshaped back to the first three dimensions of `x`.
    func callAsFunction(
        _ x: MLXArray, topKIndices: MLXArray, topKWeights: MLXArray
    ) -> MLXArray {
        let (batch, length, hidden) = (x.dim(0), x.dim(1), x.dim(2))
        let positions = batch * length
        let selectionCount = topKIndices.dim(-1)

        // Flatten the first two axes so every position is routed independently.
        let flatInput = x.reshaped(positions, hidden)
        let flatIndices = topKIndices.reshaped(positions, selectionCount)
        let perExpert = switchGLU(flatInput, flatIndices)

        // A trailing unit axis lets the weights broadcast across the hidden dimension.
        let mixing = topKWeights.reshaped(positions, selectionCount, 1).asType(perExpert.dtype)
        let combined = (perExpert * mixing).sum(axis: -2)
        return combined.reshaped(batch, length, hidden)
    }
}
475+
377476
// MARK: - Decoder Layer
378477

379478
private class Gemma4DecoderLayer: Module {
@@ -388,6 +487,11 @@ private class Gemma4DecoderLayer: Module {
388487
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm
389488
@ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayernorm: RMSNorm
390489
@ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayernorm: RMSNorm
490+
@ModuleInfo(key: "router") var router: Gemma4TextRouter?
491+
@ModuleInfo(key: "experts") var experts: Gemma4TextExperts?
492+
@ModuleInfo(key: "post_feedforward_layernorm_1") var postFeedforwardLayernorm1: RMSNorm?
493+
@ModuleInfo(key: "post_feedforward_layernorm_2") var postFeedforwardLayernorm2: RMSNorm?
494+
@ModuleInfo(key: "pre_feedforward_layernorm_2") var preFeedforwardLayernorm2: RMSNorm?
391495

392496
// Per-layer input (PLE) gating
393497
@ModuleInfo(key: "per_layer_input_gate") var perLayerInputGate: Linear?
@@ -415,6 +519,17 @@ private class Gemma4DecoderLayer: Module {
415519
self._postFeedforwardLayernorm.wrappedValue = RMSNorm(
416520
dimensions: config.hiddenSize, eps: config.rmsNormEps)
417521

522+
if config.enableMoEBlock {
523+
self._router.wrappedValue = Gemma4TextRouter(config)
524+
self._experts.wrappedValue = Gemma4TextExperts(config)
525+
self._postFeedforwardLayernorm1.wrappedValue = RMSNorm(
526+
dimensions: config.hiddenSize, eps: config.rmsNormEps)
527+
self._postFeedforwardLayernorm2.wrappedValue = RMSNorm(
528+
dimensions: config.hiddenSize, eps: config.rmsNormEps)
529+
self._preFeedforwardLayernorm2.wrappedValue = RMSNorm(
530+
dimensions: config.hiddenSize, eps: config.rmsNormEps)
531+
}
532+
418533
if hiddenSizePerLayerInput > 0 {
419534
self._perLayerInputGate.wrappedValue = Linear(
420535
config.hiddenSize, hiddenSizePerLayerInput, bias: false)
@@ -446,8 +561,26 @@ private class Gemma4DecoderLayer: Module {
446561
var out = residual + postAttn
447562

448563
let residual2 = out
449-
out = preFeedforwardLayernorm(out)
450-
out = mlp(out)
564+
if let router, let experts,
565+
let postFeedforwardLayernorm1,
566+
let postFeedforwardLayernorm2,
567+
let preFeedforwardLayernorm2
568+
{
569+
// MoE: dual dense + sparse feedforward
570+
var dense = preFeedforwardLayernorm(out)
571+
dense = mlp(dense)
572+
dense = postFeedforwardLayernorm1(dense)
573+
574+
let (topKIndices, topKWeights) = router(out)
575+
var sparse = preFeedforwardLayernorm2(out)
576+
sparse = experts(sparse, topKIndices: topKIndices, topKWeights: topKWeights)
577+
sparse = postFeedforwardLayernorm2(sparse)
578+
579+
out = dense + sparse
580+
} else {
581+
out = preFeedforwardLayernorm(out)
582+
out = mlp(out)
583+
}
451584
out = postFeedforwardLayernorm(out)
452585
out = residual2 + out
453586

@@ -675,6 +808,34 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider {
675808
{
676809
continue
677810
}
811+
812+
// MoE expert weight remapping: fused HF tensors → SwitchGLU layout
813+
if k.hasSuffix(".experts.down_proj") {
814+
sanitized[
815+
k.replacingOccurrences(
816+
of: ".experts.down_proj",
817+
with: ".experts.switch_glu.down_proj.weight"
818+
)
819+
] = v
820+
continue
821+
}
822+
if k.hasSuffix(".experts.gate_up_proj") {
823+
let mid = v.dim(-2) / 2
824+
sanitized[
825+
k.replacingOccurrences(
826+
of: ".experts.gate_up_proj",
827+
with: ".experts.switch_glu.gate_proj.weight"
828+
)
829+
] = v[.ellipsis, ..<mid, 0...]
830+
sanitized[
831+
k.replacingOccurrences(
832+
of: ".experts.gate_up_proj",
833+
with: ".experts.switch_glu.up_proj.weight"
834+
)
835+
] = v[.ellipsis, mid..., 0...]
836+
continue
837+
}
838+
678839
sanitized[k] = v
679840
}
680841
return sanitized

Tests/MLXLMTests/Gemma4Tests.swift

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,59 @@ extension MLXTestingSuite {
9797
#expect(!sum.isNaN)
9898
#expect(!sum.isInfinite)
9999
}
100+
101+
/// Creates a minimal test configuration for Gemma 4 Text with the MoE block enabled.
///
/// The JSON describes a two-layer model with 4 experts, 2 of which are selected
/// per position (`top_k_experts`).
///
/// - Returns: The configuration JSON encoded as UTF-8 bytes.
private func makeTinyTextMoEConfigData() -> Data {
    let json = """
        {
        "model_type": "gemma4_text",
        "hidden_size": 64,
        "num_hidden_layers": 2,
        "intermediate_size": 128,
        "num_attention_heads": 4,
        "head_dim": 16,
        "global_head_dim": 64,
        "rms_norm_eps": 1e-6,
        "vocab_size": 100,
        "num_key_value_heads": 2,
        "rope_traditional": false,
        "sliding_window": 128,
        "sliding_window_pattern": 1,
        "max_position_embeddings": 512,
        "num_kv_shared_layers": 0,
        "use_double_wide_mlp": false,
        "tie_word_embeddings": true,
        "hidden_size_per_layer_input": 32,
        "vocab_size_per_layer_input": 10,
        "final_logit_softcapping": 30.0,
        "enable_moe_block": true,
        "num_experts": 4,
        "top_k_experts": 2,
        "moe_intermediate_size": 128,
        "attention_k_eq_v": false
        }
        """
    // Building `Data` from the UTF-8 view cannot fail, so no force unwrap is needed.
    return Data(json.utf8)
}
134+
135+
@Test("Gemma 4 Text MoE Instantiation & Forward Pass")
func testGemma4TextMoEInstantiationAndForward() throws {
    let config = try JSONDecoder().decode(
        Gemma4TextConfiguration.self, from: makeTinyTextMoEConfigData())
    let model = Gemma4TextModel(config)
    #expect(model.vocabularySize == 100)

    // Smoke test: the model is built from the configuration alone (no checkpoint
    // weights are loaded here), so this only checks that the MoE-enabled layers
    // can be constructed and run a forward pass without crashing.
    let tokens = MLXArray(0..<8).reshaped(1, 8)
    let logits = model(tokens, cache: nil)

    // One row of vocabulary logits per input token.
    #expect(logits.shape == [1, 8, model.vocabularySize])

    // The summed logits must be a finite number.
    let total = logits.sum().item(Float.self)
    #expect(!total.isNaN)
    #expect(!total.isInfinite)
}
100154
}
101155
}

0 commit comments

Comments
 (0)