Skip to content

Commit 6a57d00

Browse files
author
Aegis-AI
committed
fix(moe): address Copilot review comments for stacked MoE fast path
- Compute and use separate `_stackedDownBytesPerExpert` for the down projection so it doesn't incorrectly reuse the gate/up stride. - Fix docstring for `computeExpertsFused` to refer generically to outputDims/inputDims instead of intermediate/hidden. - Add `StackedMoETests.swift` unit test to verify the fast path cleanly falls back without crashing when enabled on non-quantized models.
1 parent 86c9307 commit 6a57d00

2 files changed

Lines changed: 84 additions & 8 deletions

File tree

Libraries/MLXLMCommon/SwitchLayers.swift

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
102102
private var _tokenCounter: Int = 0
103103
// Bytes per expert slab in a stacked buffer; computed once on cold init.
104104
private var _stackedBytesPerExpert: Int = 0
105+
private var _stackedDownBytesPerExpert: Int = 0
105106

106107
// ── Fused gate+up SwiGLU mode (env-gated MLX_MOE_FUSE_GATEUP=1) ──
107108
// SwiGLU MLP is `silu(gate(x)) * up(x)`; gate and up are independent
@@ -204,7 +205,8 @@ public class SwitchGLU: Module, @unchecked Sendable {
204205
if let cb = _combinedGateUpBiases { coldEvalList.append(cb) }
205206
MLX.eval(coldEvalList)
206207
_stackedGateUpBytesPerProj = _stackedGateUp!.nbytes / CACHE_SLOTS / 2
207-
_stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down
208+
_stackedBytesPerExpert = _stackedGateUpBytesPerProj
209+
_stackedDownBytesPerExpert = _stackedDown!.nbytes / CACHE_SLOTS
208210
} else {
209211
_stackedGate = MLXArray.zeros(
210212
[CACHE_SLOTS, qGate.weight.dim(1), qGate.weight.dim(2)]
@@ -220,6 +222,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
220222
_tokenCounter = 0
221223
MLX.eval([idx, _stackedGate!, _stackedUp!, _stackedDown!])
222224
_stackedBytesPerExpert = _stackedGate!.nbytes / CACHE_SLOTS
225+
_stackedDownBytesPerExpert = _stackedDown!.nbytes / CACHE_SLOTS
223226
}
224227
} else {
225228
// Warm path: kick off GPU work asynchronously while we
@@ -268,6 +271,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
268271
}
269272
if !specTargets.isEmpty {
270273
let bpe = _stackedBytesPerExpert
274+
let downBpe = _stackedDownBytesPerExpert
271275
DispatchQueue.concurrentPerform(iterations: specTargets.count * 3) { [specTargets] i in
272276
let mIdx = i / 3
273277
let proj = i % 3
@@ -295,7 +299,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
295299
}
296300
default:
297301
MLXFast.preadIntoOffset(self._stackedDown!, safetensorsPath: downSSD.path,
298-
tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * bpe)
302+
tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * downBpe)
299303
}
300304
}
301305
}
@@ -367,6 +371,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
367371
// ── Pread misses into stacked-buffer slots ──
368372
if !missesNeedingPread.isEmpty {
369373
let bpe = _stackedBytesPerExpert
374+
let downBpe = _stackedDownBytesPerExpert
370375
DispatchQueue.concurrentPerform(iterations: missesNeedingPread.count * 3) { [missesNeedingPread] i in
371376
let mIdx = i / 3
372377
let proj = i % 3
@@ -392,7 +397,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
392397
}
393398
default:
394399
MLXFast.preadIntoOffset(self._stackedDown!, safetensorsPath: downSSD.path,
395-
tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * bpe)
400+
tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * downBpe)
396401
}
397402
}
398403
}
@@ -1183,8 +1188,8 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
11831188
/// single dispatch over the full stacked weight buffer.
11841189
///
11851190
/// - Parameters:
1186-
/// - x: input activations, shape `[totalTokens, ..., hidden]`.
1187-
/// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, intermediate, hidden]`.
1191+
/// - x: input activations, shape `[totalTokens, ..., inputDims]`.
1192+
/// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, outputDims, inputDims]`.
11881193
/// Slots are populated externally via `MLXFast.preadIntoOffset`.
11891194
/// - slotPerToken: uint32 array mapping each token (along axis 0 of `x`)
11901195
/// to a slot index in `stackedBuffer`. Built from the routing.
@@ -1198,7 +1203,7 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
11981203
) -> MLXArray {
11991204
let slotExpertsMLX = MLXArray(slotExperts).asType(.uint32)
12001205
// Gather scales/biases for the experts currently in our slots.
1201-
// Result shape: [N_slots, intermediate, hidden / groupSize].
1206+
// Result shape: [N_slots, outputDims, inputDims / groupSize].
12021207
let stackedScales = MLX.take(self.scales, slotExpertsMLX, axis: 0)
12031208
var stackedBiases: MLXArray? = nil
12041209
if let b = self.biases { stackedBiases = MLX.take(b, slotExpertsMLX, axis: 0) }
@@ -1214,8 +1219,8 @@ public class QuantizedSwitchLinear: SwitchLinear, Quantized {
12141219

12151220
// Optional per-token bias add (gathered from per-slot bias).
12161221
if let bias = self.bias {
1217-
let stackedBias = MLX.take(bias, slotExpertsMLX, axis: 0) // [N_slots, intermediate]
1218-
let perTokenBias = MLX.take(stackedBias, slotPerToken, axis: 0) // [tokens, intermediate]
1222+
let stackedBias = MLX.take(bias, slotExpertsMLX, axis: 0) // [N_slots, outputDims]
1223+
let perTokenBias = MLX.take(stackedBias, slotPerToken, axis: 0) // [tokens, outputDims]
12191224
output = output + MLX.expandedDimensions(perTokenBias, axis: -2)
12201225
}
12211226

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import Foundation
2+
import MLX
3+
import MLXLLM
4+
import MLXLMCommon
5+
import MLXNN
6+
import Testing
7+
8+
@Suite
9+
struct StackedMoETests {
10+
11+
/// Create a minimal test configuration for Gemma 4 Text MoE
12+
private func makeTinyTextMoEConfigData() -> Data {
13+
let json = """
14+
{
15+
"model_type": "gemma4_text",
16+
"hidden_size": 64,
17+
"num_hidden_layers": 2,
18+
"intermediate_size": 128,
19+
"num_attention_heads": 4,
20+
"head_dim": 16,
21+
"global_head_dim": 64,
22+
"rms_norm_eps": 1e-6,
23+
"vocab_size": 100,
24+
"num_key_value_heads": 2,
25+
"rope_traditional": false,
26+
"sliding_window": 128,
27+
"sliding_window_pattern": 1,
28+
"max_position_embeddings": 512,
29+
"num_kv_shared_layers": 0,
30+
"use_double_wide_mlp": false,
31+
"tie_word_embeddings": true,
32+
"hidden_size_per_layer_input": 32,
33+
"vocab_size_per_layer_input": 10,
34+
"final_logit_softcapping": 30.0,
35+
"enable_moe_block": true,
36+
"num_experts": 4,
37+
"top_k_experts": 2,
38+
"moe_intermediate_size": 128,
39+
"attention_k_eq_v": false
40+
}
41+
"""
42+
return json.data(using: .utf8)!
43+
}
44+
45+
@Test("Stacked MoE fast path falls back for non-quantized models")
46+
func testStackedMoEFallback() throws {
47+
// Set env vars directly. Since tests run concurrently, this might affect others
48+
// if SwitchGLU is initialized here first, which is fine since the fallback is safe.
49+
setenv("MLX_MOE_STACKED", "1", 1)
50+
setenv("MLX_MOE_FUSE_GATEUP", "1", 1)
51+
defer {
52+
unsetenv("MLX_MOE_STACKED")
53+
unsetenv("MLX_MOE_FUSE_GATEUP")
54+
}
55+
56+
let data = makeTinyTextMoEConfigData()
57+
let config = try JSONDecoder().decode(Gemma4TextConfiguration.self, from: data)
58+
let model = Gemma4TextModel(config)
59+
60+
// This validates that the fast path falls back cleanly because
61+
// the weights are not quantized (they are standard MLXArray).
62+
let input = MLXArray(0..<8).reshaped(1, 8)
63+
let output = model(input, cache: nil)
64+
65+
#expect(output.shape == [1, 8, model.vocabularySize])
66+
67+
let sum = output.sum().item(Float.self)
68+
#expect(!sum.isNaN)
69+
#expect(!sum.isInfinite)
70+
}
71+
}

0 commit comments

Comments
 (0)