@@ -102,6 +102,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
102102 private var _tokenCounter : Int = 0
103103 // Bytes per expert slab in a stacked buffer; computed once on cold init.
104104 private var _stackedBytesPerExpert : Int = 0
105+ private var _stackedDownBytesPerExpert : Int = 0
105106
106107 // ── Fused gate+up SwiGLU mode (env-gated MLX_MOE_FUSE_GATEUP=1) ──
107108 // SwiGLU MLP is `silu(gate(x)) * up(x)`; gate and up are independent
@@ -204,7 +205,8 @@ public class SwitchGLU: Module, @unchecked Sendable {
204205 if let cb = _combinedGateUpBiases { coldEvalList. append ( cb) }
205206 MLX . eval ( coldEvalList)
206207 _stackedGateUpBytesPerProj = _stackedGateUp!. nbytes / CACHE_SLOTS / 2
207- _stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down
208+ _stackedBytesPerExpert = _stackedGateUpBytesPerProj
209+ _stackedDownBytesPerExpert = _stackedDown!. nbytes / CACHE_SLOTS
208210 } else {
209211 _stackedGate = MLXArray . zeros (
210212 [ CACHE_SLOTS, qGate. weight. dim ( 1 ) , qGate. weight. dim ( 2 ) ]
@@ -220,6 +222,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
220222 _tokenCounter = 0
221223 MLX . eval ( [ idx, _stackedGate!, _stackedUp!, _stackedDown!] )
222224 _stackedBytesPerExpert = _stackedGate!. nbytes / CACHE_SLOTS
225+ _stackedDownBytesPerExpert = _stackedDown!. nbytes / CACHE_SLOTS
223226 }
224227 } else {
225228 // Warm path: kick off GPU work asynchronously while we
@@ -268,6 +271,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
268271 }
269272 if !specTargets. isEmpty {
270273 let bpe = _stackedBytesPerExpert
274+ let downBpe = _stackedDownBytesPerExpert
271275 DispatchQueue . concurrentPerform ( iterations: specTargets. count * 3 ) { [ specTargets] i in
272276 let mIdx = i / 3
273277 let proj = i % 3
@@ -295,7 +299,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
295299 }
296300 default :
297301 MLXFast . preadIntoOffset ( self . _stackedDown!, safetensorsPath: downSSD. path,
298- tensorName: downSSD. tensorName, expertIndex: UInt32 ( info. expertId) , dstOffset: info. slot * bpe )
302+ tensorName: downSSD. tensorName, expertIndex: UInt32 ( info. expertId) , dstOffset: info. slot * downBpe )
299303 }
300304 }
301305 }
@@ -367,6 +371,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
367371 // ── Pread misses into stacked-buffer slots ──
368372 if !missesNeedingPread. isEmpty {
369373 let bpe = _stackedBytesPerExpert
374+ let downBpe = _stackedDownBytesPerExpert
370375 DispatchQueue . concurrentPerform ( iterations: missesNeedingPread. count * 3 ) { [ missesNeedingPread] i in
371376 let mIdx = i / 3
372377 let proj = i % 3
@@ -392,7 +397,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
392397 }
393398 default :
394399 MLXFast . preadIntoOffset ( self . _stackedDown!, safetensorsPath: downSSD. path,
395- tensorName: downSSD. tensorName, expertIndex: UInt32 ( info. expertId) , dstOffset: info. slot * bpe )
400+ tensorName: downSSD. tensorName, expertIndex: UInt32 ( info. expertId) , dstOffset: info. slot * downBpe )
396401 }
397402 }
398403 }
@@ -1183,8 +1188,8 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
11831188 /// single dispatch over the full stacked weight buffer.
11841189 ///
11851190 /// - Parameters:
1186- /// - x: input activations, shape `[totalTokens, ..., hidden ]`.
1187- /// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, intermediate, hidden ]`.
1191+ /// - x: input activations, shape `[totalTokens, ..., inputDims ]`.
1192+ /// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, outputDims, inputDims ]`.
11881193 /// Slots are populated externally via `MLXFast.preadIntoOffset`.
11891194 /// - slotPerToken: uint32 array mapping each token (along axis 0 of `x`)
11901195 /// to a slot index in `stackedBuffer`. Built from the routing.
@@ -1198,7 +1203,7 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
11981203 ) -> MLXArray {
11991204 let slotExpertsMLX = MLXArray ( slotExperts) . asType ( . uint32)
12001205 // Gather scales/biases for the experts currently in our slots.
1201- // Result shape: [N_slots, intermediate, hidden / groupSize].
1206+ // Result shape: [N_slots, outputDims, inputDims / groupSize].
12021207 let stackedScales = MLX . take ( self . scales, slotExpertsMLX, axis: 0 )
12031208 var stackedBiases : MLXArray ? = nil
12041209 if let b = self . biases { stackedBiases = MLX . take ( b, slotExpertsMLX, axis: 0 ) }
@@ -1214,8 +1219,8 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
12141219
12151220 // Optional per-token bias add (gathered from per-slot bias).
12161221 if let bias = self . bias {
1217- let stackedBias = MLX . take ( bias, slotExpertsMLX, axis: 0 ) // [N_slots, intermediate ]
1218- let perTokenBias = MLX . take ( stackedBias, slotPerToken, axis: 0 ) // [tokens, intermediate ]
1222+ let stackedBias = MLX . take ( bias, slotExpertsMLX, axis: 0 ) // [N_slots, outputDims ]
1223+ let perTokenBias = MLX . take ( stackedBias, slotPerToken, axis: 0 ) // [tokens, outputDims ]
12191224 output = output + MLX. expandedDimensions ( perTokenBias, axis: - 2 )
12201225 }
12211226
0 commit comments