@@ -33,6 +33,10 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
3333 var attentionKeqV : Bool = false
3434 var finalLogitSoftcapping : Float = 30.0
3535 var useDoubleWideMlp : Bool = true
36+ var enableMoEBlock : Bool = false
37+ var numExperts : Int ?
38+ var topKExperts : Int ?
39+ var moeIntermediateSize : Int ?
3640 var layerTypes : [ String ] = [ ]
3741 var tieWordEmbeddings : Bool = true
3842
@@ -66,6 +70,10 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
6670 case attentionKeqV = " attention_k_eq_v "
6771 case finalLogitSoftcapping = " final_logit_softcapping "
6872 case useDoubleWideMlp = " use_double_wide_mlp "
73+ case enableMoEBlock = " enable_moe_block "
74+ case numExperts = " num_experts "
75+ case topKExperts = " top_k_experts "
76+ case moeIntermediateSize = " moe_intermediate_size "
6977 case layerTypes = " layer_types "
7078 case tieWordEmbeddings = " tie_word_embeddings "
7179 case ropeParameters = " rope_parameters "
@@ -110,6 +118,14 @@ public struct Gemma4TextConfiguration: Codable, Sendable {
110118 try container. decodeIfPresent ( Float . self, forKey: . finalLogitSoftcapping) ?? 30.0
111119 self . useDoubleWideMlp =
112120 try container. decodeIfPresent ( Bool . self, forKey: . useDoubleWideMlp) ?? true
121+ self . enableMoEBlock =
122+ try container. decodeIfPresent ( Bool . self, forKey: . enableMoEBlock) ?? false
123+ self . numExperts =
124+ try container. decodeIfPresent ( Int . self, forKey: . numExperts)
125+ self . topKExperts =
126+ try container. decodeIfPresent ( Int . self, forKey: . topKExperts)
127+ self . moeIntermediateSize =
128+ try container. decodeIfPresent ( Int . self, forKey: . moeIntermediateSize)
113129 if let decoded = try container. decodeIfPresent ( [ String ] . self, forKey: . layerTypes) {
114130 self . layerTypes = decoded
115131 } else {
@@ -374,6 +390,89 @@ private class Gemma4MLP: Module {
374390 }
375391}
376392
393+ // MARK: - MoE Router
394+
/// Token-to-expert router for the Gemma4 mixture-of-experts feedforward path.
///
/// For each token the router normalizes the hidden state, rescales it, projects it to
/// one score per expert, and returns the indices of the `topKExperts` highest-scoring
/// experts together with the mixing weight to apply to each of them.
private class Gemma4TextRouter: Module {
    /// Number of experts selected for every token.
    let topKExperts: Int
    /// `hiddenSize ^ -0.5`; multiplied into the normalized input before projection.
    let rootSize: Float

    @ModuleInfo(key: "norm") var norm: RMSNormNoScale
    @ModuleInfo(key: "proj") var proj: Linear
    // Raw MLXArray parameters are declared with @ParameterInfo; @ModuleInfo is meant
    // for child modules.
    @ParameterInfo(key: "scale") var scale: MLXArray
    @ParameterInfo(key: "per_expert_scale") var perExpertScale: MLXArray

    /// Builds the router from the MoE fields of `config`.
    ///
    /// Traps if `numExperts` or `topKExperts` is missing, or if `topKExperts` is not in
    /// `1...numExperts`.
    init(_ config: Gemma4TextConfiguration) {
        guard let numExperts = config.numExperts, let topKExperts = config.topKExperts else {
            fatalError("Gemma4 MoE router requires numExperts and topKExperts")
        }
        // Fail here with a readable message instead of later inside
        // argPartition(kth: topKExperts - 1) or the `..<topKExperts` slice.
        precondition(
            topKExperts >= 1 && topKExperts <= numExperts,
            "Gemma4 MoE router requires 1 <= topKExperts <= numExperts "
                + "(got topKExperts=\(topKExperts), numExperts=\(numExperts))"
        )

        self.topKExperts = topKExperts
        self.rootSize = pow(Float(config.hiddenSize), -0.5)

        self._norm.wrappedValue = RMSNormNoScale(eps: config.rmsNormEps)
        self._proj.wrappedValue = Linear(config.hiddenSize, numExperts, bias: false)
        self._scale.wrappedValue = MLXArray.ones([config.hiddenSize])
        self._perExpertScale.wrappedValue = MLXArray.ones([numExperts])
        super.init()
    }

    /// Routes hidden states to experts.
    ///
    /// - Parameter x: Hidden states whose last axis is projected by `proj`
    ///   (presumably `hiddenSize` wide; confirm against the caller).
    /// - Returns: `(topKIndices, topKWeights)`. Both have `topKExperts` entries along
    ///   the last axis: the selected expert ids and their mixing weights.
    func callAsFunction(_ x: MLXArray) -> (MLXArray, MLXArray) {
        // Normalize, then apply the fixed 1/sqrt(hiddenSize) factor and the learned scale.
        var x = norm(x)
        x = x * MLXArray(rootSize, dtype: x.dtype)
        x = x * scale.asType(x.dtype)

        let expertScores = proj(x)
        let routerProbabilities = MLX.softmax(expertScores, axis: -1, precise: true)

        // Partitioning the negated scores places the k largest scores in the first k
        // slots; only those slots are kept.
        let topKIndices = MLX.argPartition(-expertScores, kth: topKExperts - 1, axis: -1)[
            .ellipsis, ..<topKExperts,
        ]
        // Take the softmax probabilities of the chosen experts and renormalize them to
        // sum to 1, then apply the learned per-expert scale.
        var topKWeights = MLX.takeAlong(routerProbabilities, topKIndices, axis: -1)
        topKWeights = topKWeights / MLX.sum(topKWeights, axis: -1, keepDims: true)
        topKWeights = topKWeights * perExpertScale[topKIndices].asType(topKWeights.dtype)
        return (topKIndices, topKWeights)
    }
}
436+
437+ // MARK: - MoE Experts
438+
/// Sparse expert feedforward for the Gemma4 mixture-of-experts block.
///
/// Wraps a single `SwitchGLU` holding all experts and mixes the outputs of the
/// experts chosen for each token using the supplied routing weights.
private class Gemma4TextExperts: Module {
    @ModuleInfo(key: "switch_glu") var switchGLU: SwitchGLU

    /// Builds the expert bank from `config`.
    ///
    /// Traps if `numExperts` or `moeIntermediateSize` is missing from the configuration.
    init(_ config: Gemma4TextConfiguration) {
        guard
            let expertCount = config.numExperts,
            let expertHiddenDims = config.moeIntermediateSize
        else {
            fatalError("Gemma4 MoE experts require numExperts and moeIntermediateSize")
        }

        let glu = SwitchGLU(
            inputDims: config.hiddenSize,
            hiddenDims: expertHiddenDims,
            numExperts: expertCount,
            activation: geluApproximate,
            bias: false
        )
        self._switchGLU.wrappedValue = glu
        super.init()
    }

    /// Runs the selected experts and returns their weighted sum per token.
    ///
    /// - Parameters:
    ///   - x: Input read as three axes `(batch, length, hidden)`.
    ///   - topKIndices: Selected expert ids; the last axis holds the per-token choices.
    ///   - topKWeights: Mixing weight for each entry of `topKIndices`.
    /// - Returns: An array reshaped back to `(batch, length, hidden)`.
    func callAsFunction(
        _ x: MLXArray, topKIndices: MLXArray, topKWeights: MLXArray
    ) -> MLXArray {
        let (batch, length, hidden) = (x.dim(0), x.dim(1), x.dim(2))
        let tokenCount = batch * length
        let selected = topKIndices.dim(-1)

        // Collapse batch and sequence so every row is one token.
        let flatInput = x.reshaped(tokenCount, hidden)
        let flatIndices = topKIndices.reshaped(tokenCount, selected)
        let perExpert = switchGLU(flatInput, flatIndices)

        // A trailing singleton axis lets each weight broadcast across its expert's output.
        let mixing = topKWeights.reshaped(tokenCount, selected, 1).asType(perExpert.dtype)
        let combined = (perExpert * mixing).sum(axis: -2)
        return combined.reshaped(batch, length, hidden)
    }
}
475+
377476// MARK: - Decoder Layer
378477
379478private class Gemma4DecoderLayer : Module {
@@ -388,6 +487,11 @@ private class Gemma4DecoderLayer: Module {
388487 @ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayernorm : RMSNorm
389488 @ModuleInfo ( key: " pre_feedforward_layernorm " ) var preFeedforwardLayernorm : RMSNorm
390489 @ModuleInfo ( key: " post_feedforward_layernorm " ) var postFeedforwardLayernorm : RMSNorm
490+ @ModuleInfo ( key: " router " ) var router : Gemma4TextRouter ?
491+ @ModuleInfo ( key: " experts " ) var experts : Gemma4TextExperts ?
492+ @ModuleInfo ( key: " post_feedforward_layernorm_1 " ) var postFeedforwardLayernorm1 : RMSNorm ?
493+ @ModuleInfo ( key: " post_feedforward_layernorm_2 " ) var postFeedforwardLayernorm2 : RMSNorm ?
494+ @ModuleInfo ( key: " pre_feedforward_layernorm_2 " ) var preFeedforwardLayernorm2 : RMSNorm ?
391495
392496 // Per-layer input (PLE) gating
393497 @ModuleInfo ( key: " per_layer_input_gate " ) var perLayerInputGate : Linear ?
@@ -415,6 +519,17 @@ private class Gemma4DecoderLayer: Module {
415519 self . _postFeedforwardLayernorm. wrappedValue = RMSNorm (
416520 dimensions: config. hiddenSize, eps: config. rmsNormEps)
417521
522+ if config. enableMoEBlock {
523+ self . _router. wrappedValue = Gemma4TextRouter ( config)
524+ self . _experts. wrappedValue = Gemma4TextExperts ( config)
525+ self . _postFeedforwardLayernorm1. wrappedValue = RMSNorm (
526+ dimensions: config. hiddenSize, eps: config. rmsNormEps)
527+ self . _postFeedforwardLayernorm2. wrappedValue = RMSNorm (
528+ dimensions: config. hiddenSize, eps: config. rmsNormEps)
529+ self . _preFeedforwardLayernorm2. wrappedValue = RMSNorm (
530+ dimensions: config. hiddenSize, eps: config. rmsNormEps)
531+ }
532+
418533 if hiddenSizePerLayerInput > 0 {
419534 self . _perLayerInputGate. wrappedValue = Linear (
420535 config. hiddenSize, hiddenSizePerLayerInput, bias: false )
@@ -446,8 +561,26 @@ private class Gemma4DecoderLayer: Module {
446561 var out = residual + postAttn
447562
448563 let residual2 = out
449- out = preFeedforwardLayernorm ( out)
450- out = mlp ( out)
564+ if let router, let experts,
565+ let postFeedforwardLayernorm1,
566+ let postFeedforwardLayernorm2,
567+ let preFeedforwardLayernorm2
568+ {
569+ // MoE: dual dense + sparse feedforward
570+ var dense = preFeedforwardLayernorm ( out)
571+ dense = mlp ( dense)
572+ dense = postFeedforwardLayernorm1 ( dense)
573+
574+ let ( topKIndices, topKWeights) = router ( out)
575+ var sparse = preFeedforwardLayernorm2 ( out)
576+ sparse = experts ( sparse, topKIndices: topKIndices, topKWeights: topKWeights)
577+ sparse = postFeedforwardLayernorm2 ( sparse)
578+
579+ out = dense + sparse
580+ } else {
581+ out = preFeedforwardLayernorm ( out)
582+ out = mlp ( out)
583+ }
451584 out = postFeedforwardLayernorm ( out)
452585 out = residual2 + out
453586
@@ -675,6 +808,34 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider {
675808 {
676809 continue
677810 }
811+
812+ // MoE expert weight remapping: fused HF tensors → SwitchGLU layout
813+ if k. hasSuffix ( " .experts.down_proj " ) {
814+ sanitized [
815+ k. replacingOccurrences (
816+ of: " .experts.down_proj " ,
817+ with: " .experts.switch_glu.down_proj.weight "
818+ )
819+ ] = v
820+ continue
821+ }
822+ if k. hasSuffix ( " .experts.gate_up_proj " ) {
823+ let mid = v. dim ( - 2 ) / 2
824+ sanitized [
825+ k. replacingOccurrences (
826+ of: " .experts.gate_up_proj " ,
827+ with: " .experts.switch_glu.gate_proj.weight "
828+ )
829+ ] = v [ . ellipsis, ..< mid, 0 ... ]
830+ sanitized [
831+ k. replacingOccurrences (
832+ of: " .experts.gate_up_proj " ,
833+ with: " .experts.switch_glu.up_proj.weight "
834+ )
835+ ] = v [ . ellipsis, mid... , 0 ... ]
836+ continue
837+ }
838+
678839 sanitized [ k] = v
679840 }
680841 return sanitized
0 commit comments