@@ -38,6 +38,10 @@ public struct Qwen35Configuration: Codable, Sendable {
3838public class Qwen35MoEModel : Qwen35Model {
3939
4040 override public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
41+ // ── Step 1: FP8 dequantization (official Qwen3.6-35B-A3B-FP8 checkpoint) ──
42+ // The FP8 release stores quantized weights alongside weight_scale_inv tensors.
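43+ // Both are consumed eagerly at load time: expert weights are dequantized while being stacked (Steps 3–4) and every other scaled weight in Step 5; no weight_scale_inv tensor survives sanitize.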
44+ // ── Step 2: Key remapping ──
4145 var newWeights = [ String: MLXArray] ( )
4246 for (key, value) in weights {
4347 if key. hasPrefix ( " vision_tower " ) || key. hasPrefix ( " model.visual " ) {
@@ -53,45 +57,165 @@ public class Qwen35MoEModel: Qwen35Model {
5357 newWeights [ key] = value
5458 }
5559
// ── Steps 3 + 4: MoE expert weight stacking (main layers and MTP heads) ──
// Format A: community 4-bit checkpoints ship a pre-stacked "gate_up_proj" → split into gate/up
// Format B: FP8/BF16 official checkpoints ship per-expert "experts.N.{gate,up,down}_proj" → stack
// Main layers and MTP heads go through the same two local helpers so the two
// code paths cannot drift apart.
let nExperts = languageModel.configuration.numExperts

/// Dequantizes one FP8 expert matrix to bfloat16 using its block-wise inverse scale.
///
/// The matrix is padded up to a multiple of 128 in both dimensions, viewed as
/// 128×128 blocks, multiplied block-by-block by `scaleInv`, and cropped back to
/// its original `m × n` size.
/// - Parameters:
///   - weight: The raw FP8 tensor as loaded from the checkpoint.
///   - scaleInv: The matching `weight_scale_inv` tensor; it is broadcast as
///     `[rowBlocks, 1, colBlocks, 1]` against the blocked weight.
/// - Returns: The dequantized `m × n` matrix in bfloat16.
func dequantizeFP8Blocks(_ weight: MLXArray, scaleInv: MLXArray) -> MLXArray {
    let bs = 128
    let wFp = MLXFast.fromFp8(weight, dtype: .bfloat16)
    let (m, n) = (wFp.dim(0), wFp.dim(1))
    let padB = (bs - m % bs) % bs
    let padS = (bs - n % bs) % bs
    var p = MLX.padded(wFp, widths: [[0, padB], [0, padS]])
    p = p.reshaped([(m + padB) / bs, bs, (n + padS) / bs, bs])
    let scaled = p * scaleInv[0..., .newAxis, 0..., .newAxis]
    return scaled.reshaped([m + padB, n + padS])[0 ..< m, 0 ..< n].asType(.bfloat16)
}

/// Rewrites the expert tensors stored under `prefix` into `switch_mlp.<proj>.weight`
/// entries of `newWeights`, removing the source keys it consumes.
///
/// Does nothing when neither checkpoint format is present under `prefix`, so it
/// is safe to call with prefixes that do not exist (or were already processed).
func stackExperts(under prefix: String) {
    // Format A: split the fused gate/up tensor in half along its second-to-last axis.
    let gateUpKey = "\(prefix).experts.gate_up_proj"
    if let gateUp = newWeights[gateUpKey] {
        newWeights[gateUpKey] = nil
        let mid = gateUp.dim(-2) / 2
        newWeights["\(prefix).switch_mlp.gate_proj.weight"] = gateUp[.ellipsis, ..<mid, 0...]
        newWeights["\(prefix).switch_mlp.up_proj.weight"] = gateUp[.ellipsis, mid..., 0...]
        let downKey = "\(prefix).experts.down_proj"
        if let dp = newWeights[downKey] {
            newWeights[downKey] = nil
            newWeights["\(prefix).switch_mlp.down_proj.weight"] = dp
        }
    }

    // Format B: detected by the presence of expert 0's gate projection.
    guard newWeights["\(prefix).experts.0.gate_proj.weight"] != nil else { return }
    for projName in ["gate_proj", "up_proj", "down_proj"] {
        let weightKeys = (0 ..< nExperts).map { "\(prefix).experts.\($0).\(projName).weight" }
        let scaleKeys = weightKeys.map { $0 + "_scale_inv" }
        let perExpert = weightKeys.compactMap { newWeights[$0] }
        let perExpertScale = scaleKeys.compactMap { newWeights[$0] }

        // Only stack when every expert is present; a partial set is left untouched.
        guard perExpert.count == nExperts else { continue }

        let stackedKey = "\(prefix).switch_mlp.\(projName).weight"
        if perExpertScale.count == nExperts {
            // FP8 checkpoint: dequantize each expert once, here, instead of
            // re-running fromFp8 + block scaling on the stacked tensor at
            // inference time.
            let dequanted = zip(perExpert, perExpertScale).map {
                dequantizeFP8Blocks($0, scaleInv: $1)
            }
            let stacked = MLX.stacked(dequanted)
            // Eagerly eval to pay the dequant cost at load time, not during prefill.
            // Without this, the entire lazy graph materializes on first forward pass.
            MLX.eval(stacked)
            newWeights[stackedKey] = stacked
            // Scale tensors are consumed — do NOT keep weight_scale_inv.
            for key in weightKeys + scaleKeys {
                newWeights.removeValue(forKey: key)
            }
        } else {
            // BF16 checkpoint: stack as-is.
            newWeights[stackedKey] = MLX.stacked(perExpert)
            for key in weightKeys {
                newWeights.removeValue(forKey: key)
            }
        }
    }
}

// Main decoder layers.
for l in 0 ..< languageModel.configuration.hiddenLayers {
    stackExperts(under: "language_model.model.layers.\(l).mlp")
}

// MTP heads. Several key layouts are probed; a prefix that repeats (e.g. the
// last two coincide when l == 0) is a no-op the second time because the first
// pass already consumed its keys.
for l in 0 ..< languageModel.configuration.numNextnPredictLayers {
    let prefixes = [
        "language_model.mtp.\(l).layers.0.mlp",
        "language_model.mtp.layers.0.mlp",
        "language_model.mtp.layers.\(l).mlp"
    ]
    for prefix in prefixes {
        stackExperts(under: prefix)
    }
}
94187
// ── Step 5: Eager FP8 block-wise dequantization for remaining non-expert Linear layers ──
// After Steps 3+4, the switch_mlp expert scale tensors have been consumed during stacking.
// Any remaining "weight_scale_inv" keys belong to regular Linear layers
// (attention projections, shared_expert, GatedDeltaNet, lm_head, etc.).
// Linear has no slot for a scale tensor, so each scaled weight is dequantized
// here and its scale is dropped.
//
// Dictionary iteration order is unspecified, so a weight may be visited before
// or after its scale. Both branches therefore decide from lookups in
// `newWeights` (never from what has already been copied into `processed`):
// the scale entry always writes the dequantized weight, and a weight that has
// a scale is never copied raw.
let scaleSuffix = "_scale_inv"
let bs = 128  // each scale entry covers one 128×128 block of the weight
var processed = [String: MLXArray]()
processed.reserveCapacity(newWeights.count)
for (key, value) in newWeights {
    if key.hasSuffix(".weight" + scaleSuffix) {
        // Strip only the trailing suffix; the rest of the key is left untouched.
        let wKey = String(key.dropLast(scaleSuffix.count))
        // A scale without a matching weight is simply dropped.
        guard let w = newWeights[wKey] else { continue }

        // Swift MLX maps F8_E4M3 → uint8; fromFp8 gives proper signed floats.
        let wFp: MLXArray = MLXFast.fromFp8(w, dtype: .bfloat16)
        let (m, n) = (wFp.dim(0), wFp.dim(1))
        let padBottom = (bs - m % bs) % bs
        let padSide = (bs - n % bs) % bs
        // Pad to whole blocks, view as [rowBlocks, bs, colBlocks, bs], scale
        // each block, then crop back to the original m × n.
        var padded = MLX.padded(wFp, widths: [[0, padBottom], [0, padSide]])
        padded = padded.reshaped([(m + padBottom) / bs, bs, (n + padSide) / bs, bs])
        let scaled = padded * value[0..., .newAxis, 0..., .newAxis]
        let dequant = scaled.reshaped([m + padBottom, n + padSide])[0 ..< m, 0 ..< n]
        processed[wKey] = dequant.asType(.bfloat16)
    } else if newWeights[key + scaleSuffix] == nil {
        // Tensor without a scale: pass through unchanged. Tensors that do have
        // a scale are emitted by the branch above, already dequantized.
        processed[key] = value
    }
}
if !processed.isEmpty { newWeights = processed }
216+
217+
95218 return languageModel. sanitize ( weights: newWeights)
96219 }
220+
97221}
0 commit comments