Skip to content

Commit cadd98a

Browse files
author
Aegis-AI
committed
fix(mtp): Resolve speculative decoding memory collapse and expand MoE support
- Force MLX evaluation of mtpLogits to prevent recursive compute graph explosion (OOM).
- Apply dynamic KV cache quantization to all MTP draft heads during rewinds.
- Optimize SwitchGLU SSD streaming with async pre-reads and fused buffers.
- Add Qwen3.6-35B model definition and MTP speculation hooks.
1 parent 282b9a7 commit cadd98a

6 files changed

Lines changed: 398 additions & 106 deletions

File tree

Libraries/MLXLLM/Models/Qwen35.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,35 @@ public class Qwen35Model: Module, LLMModel, KVCacheDimensionProvider {
832832
sanitized[key] = value
833833
}
834834

835+
// FP8 block-wise dequantization for Qwen3.6-27B-FP8 (dense checkpoint).
836+
// Official FP8 checkpoints ship each weight tensor alongside a
837+
// "weight_scale_inv" tensor with shape [ceil(outFeatures/128), ceil(inFeatures/128)].
838+
// We dequantize eagerly here (dense model fits in 64 GB without lazy streaming).
839+
var processed = [String: MLXArray]()
840+
for (key, value) in sanitized {
841+
if key.hasSuffix(".weight_scale_inv") {
842+
let wKey = key.replacingOccurrences(of: "_scale_inv", with: "")
843+
if let w = sanitized[wKey], processed[wKey] == nil {
844+
// Block-wise: scale_inv is [outBlocks, inBlocks], w is [outDim, inDim]
845+
// Swift MLX maps F8_E4M3 → uint8; fromFp8 gives the same signed
846+
// [-448,448] range that Python mx.load() produces automatically.
847+
let wFp: MLXArray = MLXFast.fromFp8(w, dtype: .bfloat16)
848+
let bs = 128
849+
let (m, n) = (wFp.dim(0), wFp.dim(1))
850+
let padBottom = (bs - m % bs) % bs
851+
let padSide = (bs - n % bs) % bs
852+
var padded = MLX.padded(wFp, widths: [[0, padBottom], [0, padSide]])
853+
padded = padded.reshaped([(m + padBottom) / bs, bs, (n + padSide) / bs, bs])
854+
let scaled = padded * value[0..., .newAxis, 0..., .newAxis]
855+
let dequant = scaled.reshaped([m + padBottom, n + padSide])[0 ..< m, 0 ..< n]
856+
processed[wKey] = dequant.asType(.bfloat16)
857+
}
858+
} else if processed[key] == nil {
859+
processed[key] = value
860+
}
861+
}
862+
if !processed.isEmpty { sanitized = processed }
863+
835864
return languageModel.sanitize(weights: sanitized)
836865
}
837866
}

Libraries/MLXLLM/Models/Qwen35MoE.swift

Lines changed: 138 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ public struct Qwen35Configuration: Codable, Sendable {
3838
public class Qwen35MoEModel: Qwen35Model {
3939

4040
override public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
41+
// ── Step 1: FP8 dequantization (official Qwen3.6-35B-A3B-FP8 checkpoint) ──
42+
// The FP8 release stores quantized weights alongside weight_scale_inv tensors.
43+
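// Expert weights are dequantized eagerly per expert while being stacked (Steps 3 and 4); the scale tensors are consumed there and not stored. Remaining Linear weights are dequantized in Step 5.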
44+
// ── Step 2: Key remapping ──
4145
var newWeights = [String: MLXArray]()
4246
for (key, value) in weights {
4347
if key.hasPrefix("vision_tower") || key.hasPrefix("model.visual") {
@@ -53,45 +57,165 @@ public class Qwen35MoEModel: Qwen35Model {
5357
newWeights[key] = value
5458
}
5559

60+
// ── Step 3: MoE expert weight stacking (main layers) ──
61+
// Format A: community 4-bit checkpoints ship a pre-stacked "gate_up_proj" → split into gate/up
62+
// Format B: FP8/BF16 official checkpoints ship per-expert "experts.N.{gate,up,down}_proj" → stack
63+
let nExperts = languageModel.configuration.numExperts
5664
for l in 0 ..< languageModel.configuration.hiddenLayers {
5765
let prefix = "language_model.model.layers.\(l).mlp"
66+
67+
// Format A
5868
let gateUpKey = "\(prefix).experts.gate_up_proj"
5969
if let gateUp = newWeights[gateUpKey] {
6070
newWeights[gateUpKey] = nil
6171
let mid = gateUp.dim(-2) / 2
62-
newWeights["\(prefix).switch_mlp.gate_proj.weight"] =
63-
gateUp[.ellipsis, ..<mid, 0...]
64-
newWeights["\(prefix).switch_mlp.up_proj.weight"] =
65-
gateUp[.ellipsis, mid..., 0...]
66-
if let downProj = newWeights["\(prefix).experts.down_proj"] {
72+
newWeights["\(prefix).switch_mlp.gate_proj.weight"] = gateUp[.ellipsis, ..<mid, 0...]
73+
newWeights["\(prefix).switch_mlp.up_proj.weight"] = gateUp[.ellipsis, mid..., 0...]
74+
if let dp = newWeights["\(prefix).experts.down_proj"] {
6775
newWeights["\(prefix).experts.down_proj"] = nil
68-
newWeights["\(prefix).switch_mlp.down_proj.weight"] = downProj
76+
newWeights["\(prefix).switch_mlp.down_proj.weight"] = dp
77+
}
78+
}
79+
80+
// Format B
81+
if newWeights["\(prefix).experts.0.gate_proj.weight"] != nil {
82+
for projName in ["gate_proj", "up_proj", "down_proj"] {
83+
let perExpert = (0 ..< nExperts).compactMap {
84+
newWeights["\(prefix).experts.\($0).\(projName).weight"]
85+
}
86+
let perExpertScale = (0 ..< nExperts).compactMap {
87+
newWeights["\(prefix).experts.\($0).\(projName).weight_scale_inv"]
88+
}
89+
90+
if perExpert.count == nExperts {
91+
if perExpertScale.count == nExperts {
92+
// FP8 checkpoint: eager per-expert dequant at load time.
93+
// Avoids re-running fromFp8 + block-scale on the full [256,outDim,inDim]
94+
// stacked tensor on every forward pass (would be prohibitively slow).
95+
let bs = 128
96+
let dequanted: [MLXArray] = zip(perExpert, perExpertScale).map { w, inv in
97+
let wFp = MLXFast.fromFp8(w, dtype: .bfloat16)
98+
let (m, n) = (wFp.dim(0), wFp.dim(1))
99+
let padB = (bs - m % bs) % bs
100+
let padS = (bs - n % bs) % bs
101+
var p = MLX.padded(wFp, widths: [[0, padB], [0, padS]])
102+
p = p.reshaped([(m + padB) / bs, bs, (n + padS) / bs, bs])
103+
let scaled = p * inv[0..., .newAxis, 0..., .newAxis]
104+
return scaled.reshaped([m + padB, n + padS])[0 ..< m, 0 ..< n].asType(.bfloat16)
105+
}
106+
let stacked = MLX.stacked(dequanted)
107+
// Eagerly eval to pay the dequant cost at load time, not during prefill.
108+
// Without this, the entire lazy graph materializes on first forward pass.
109+
MLX.eval(stacked)
110+
newWeights["\(prefix).switch_mlp.\(projName).weight"] = stacked
111+
// Scale tensors consumed — do NOT store weight_scale_inv
112+
for i in 0 ..< nExperts {
113+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight")
114+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight_scale_inv")
115+
}
116+
} else {
117+
// BF16 checkpoint: stack as-is
118+
newWeights["\(prefix).switch_mlp.\(projName).weight"] = MLX.stacked(perExpert)
119+
for i in 0 ..< nExperts {
120+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight")
121+
}
122+
}
123+
}
69124
}
70125
}
71126
}
72-
127+
128+
// ── Step 4: MoE expert weight stacking (MTP heads) ──
73129
for l in 0 ..< languageModel.configuration.numNextnPredictLayers {
74130
let prefixes = [
75131
"language_model.mtp.\(l).layers.0.mlp",
76-
"language_model.mtp.layers.0.mlp"
132+
"language_model.mtp.layers.0.mlp",
133+
"language_model.mtp.layers.\(l).mlp"
77134
]
78135
for prefix in prefixes {
136+
// Format A
79137
let gateUpKey = "\(prefix).experts.gate_up_proj"
80138
if let gateUp = newWeights[gateUpKey] {
81139
newWeights[gateUpKey] = nil
82140
let mid = gateUp.dim(-2) / 2
83-
newWeights["\(prefix).switch_mlp.gate_proj.weight"] =
84-
gateUp[.ellipsis, ..<mid, 0...]
85-
newWeights["\(prefix).switch_mlp.up_proj.weight"] =
86-
gateUp[.ellipsis, mid..., 0...]
87-
if let downProj = newWeights["\(prefix).experts.down_proj"] {
141+
newWeights["\(prefix).switch_mlp.gate_proj.weight"] = gateUp[.ellipsis, ..<mid, 0...]
142+
newWeights["\(prefix).switch_mlp.up_proj.weight"] = gateUp[.ellipsis, mid..., 0...]
143+
if let dp = newWeights["\(prefix).experts.down_proj"] {
88144
newWeights["\(prefix).experts.down_proj"] = nil
89-
newWeights["\(prefix).switch_mlp.down_proj.weight"] = downProj
145+
newWeights["\(prefix).switch_mlp.down_proj.weight"] = dp
146+
}
147+
}
148+
149+
// Format B
150+
if newWeights["\(prefix).experts.0.gate_proj.weight"] != nil {
151+
for projName in ["gate_proj", "up_proj", "down_proj"] {
152+
let perExpert = (0 ..< nExperts).compactMap {
153+
newWeights["\(prefix).experts.\($0).\(projName).weight"]
154+
}
155+
let perExpertScale = (0 ..< nExperts).compactMap {
156+
newWeights["\(prefix).experts.\($0).\(projName).weight_scale_inv"]
157+
}
158+
if perExpert.count == nExperts {
159+
if perExpertScale.count == nExperts {
160+
let bs = 128
161+
let dequanted: [MLXArray] = zip(perExpert, perExpertScale).map { w, inv in
162+
let wFp = MLXFast.fromFp8(w, dtype: .bfloat16)
163+
let (m, n) = (wFp.dim(0), wFp.dim(1))
164+
let padB = (bs - m % bs) % bs; let padS = (bs - n % bs) % bs
165+
var p = MLX.padded(wFp, widths: [[0, padB], [0, padS]])
166+
p = p.reshaped([(m + padB) / bs, bs, (n + padS) / bs, bs])
167+
return (p * inv[0..., .newAxis, 0..., .newAxis]).reshaped([m + padB, n + padS])[0 ..< m, 0 ..< n].asType(.bfloat16)
168+
}
169+
let stacked = MLX.stacked(dequanted)
170+
MLX.eval(stacked)
171+
newWeights["\(prefix).switch_mlp.\(projName).weight"] = stacked
172+
for i in 0 ..< nExperts {
173+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight")
174+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight_scale_inv")
175+
}
176+
} else {
177+
newWeights["\(prefix).switch_mlp.\(projName).weight"] = MLX.stacked(perExpert)
178+
for i in 0 ..< nExperts {
179+
newWeights.removeValue(forKey: "\(prefix).experts.\(i).\(projName).weight")
180+
}
181+
}
182+
}
90183
}
91184
}
92185
}
93186
}
94187

188+
// ── Step 5: Eager FP8 block-wise dequantization for remaining non-expert Linear layers ──
189+
// After Steps 3+4, ALL switch_mlp expert scale tensors have been consumed during stacking.
190+
// Any remaining "weight_scale_inv" keys belong to regular Linear layers
191+
// (attention projections, shared_expert, GatedDeltaNet, lm_head, etc.).
192+
// These cannot carry weight_scale_inv, so we eagerly dequantize here.
193+
var processed = [String: MLXArray]()
194+
for (key, value) in newWeights {
195+
if key.hasSuffix(".weight_scale_inv") {
196+
let wKey = key.replacingOccurrences(of: "_scale_inv", with: "")
197+
if let w = newWeights[wKey], processed[wKey] == nil {
198+
// Swift MLX maps F8_E4M3 → uint8; fromFp8 gives proper signed floats.
199+
let wFp: MLXArray = MLXFast.fromFp8(w, dtype: .bfloat16)
200+
let bs = 128
201+
let (m, n) = (wFp.dim(0), wFp.dim(1))
202+
let padBottom = (bs - m % bs) % bs
203+
let padSide = (bs - n % bs) % bs
204+
var padded = MLX.padded(wFp, widths: [[0, padBottom], [0, padSide]])
205+
padded = padded.reshaped([(m + padBottom) / bs, bs, (n + padSide) / bs, bs])
206+
let scaled = padded * value[0..., .newAxis, 0..., .newAxis]
207+
let dequant = scaled.reshaped([m + padBottom, n + padSide])[0 ..< m, 0 ..< n]
208+
processed[wKey] = dequant.asType(.bfloat16)
209+
}
210+
// Drop the scale tensor — Linear has no slot for it.
211+
} else if processed[key] == nil {
212+
processed[key] = value
213+
}
214+
}
215+
if !processed.isEmpty { newWeights = processed }
216+
217+
95218
return languageModel.sanitize(weights: newWeights)
96219
}
220+
97221
}

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,19 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
11261126

11271127
// Save future MTP logits for next iteration
11281128
self.mtpLogits = mtpResult.count > 1 ? Array(mtpResult.dropFirst()) : nil
1129-
1129+
1130+
// Force evaluation of the token and saved MTP logits so the lazy compute graph does not keep growing across iterations (OOM)
1131+
var evalArrays = [token]
1132+
if let mtpLogits = self.mtpLogits { evalArrays.append(contentsOf: mtpLogits) }
1133+
eval(evalArrays)
1134+
1135+
pendingTokens.append(token.item(Int.self))
1136+
y = .init(tokens: token)
1137+
11301138
quantizeKVCache(&cache)
1139+
for i in mtpCaches.indices {
1140+
quantizeKVCache(&mtpCaches[i])
1141+
}
11311142
return
11321143
}
11331144

@@ -1163,8 +1174,7 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
11631174
mainTokens = sampler.sample(logits: verifyLogits)
11641175
}
11651176

1166-
// Compare and accept proposed tokens
1167-
eval(mainTokens, draftTokens)
1177+
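// The explicit eval() is deferred until after mtpLogits is computed, so a single call forces the whole graph; the asArray() calls below still evaluate mainTokens and draftTokens on demand.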
11681178
let mainTokensList = mainTokens.asArray(Int.self)
11691179
let draftTokensList = concatenated(draftTokens).asArray(Int.self)
11701180
var accepted = 0
@@ -1191,6 +1201,9 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
11911201

11921202
// Apply dynamic cache quantization after rewind
11931203
quantizeKVCache(&cache)
1204+
for i in mtpCaches.indices {
1205+
quantizeKVCache(&mtpCaches[i])
1206+
}
11941207

11951208
// Set y for the next round
11961209
y = .init(tokens: finalToken)
@@ -1203,6 +1216,11 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
12031216
} else {
12041217
self.mtpLogits = nil
12051218
}
1219+
1220+
// Force evaluation of the sampled tokens and saved MTP logits so the lazy compute graph does not keep growing across iterations (OOM)
1221+
var evalArrays = [mainTokens] + draftTokens
1222+
if let mtpLogits = self.mtpLogits { evalArrays.append(contentsOf: mtpLogits) }
1223+
eval(evalArrays)
12061224
}
12071225

12081226
mutating public func next() -> Int? {

Libraries/MLXLMCommon/Load.swift

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,41 @@ public func loadWeights(
8989
// and fall back to the bare path if none match.
9090
let knownPrefixes = ["language_model.", "model.language_model.", ""]
9191
for (path, module) in model.leafModules().flattened() {
92-
if let qsl = module as? QuantizedSwitchLinear {
92+
if let sl = module as? SwitchLinear {
9393
let bareName = "\(path).weight"
94-
// Find the original key that exists in the shard index
94+
95+
// First, check for unstacked format (e.g. Qwen FP8: "experts.N.gate_proj")
96+
if bareName.contains(".switch_mlp.") {
97+
let unstackedBaseName = bareName.replacingOccurrences(of: ".switch_mlp.", with: ".experts.")
98+
// Try to find expert 0 to confirm unstacked format
99+
let expert0Name = unstackedBaseName.replacingOccurrences(of: ".experts.", with: ".experts.0.")
100+
101+
var foundUnstacked = false
102+
for prefix in knownPrefixes {
103+
if ExpertStreamerManager.shared?.getFile(for: prefix + expert0Name) != nil {
104+
foundUnstacked = true
105+
var map = [Int: (path: String, tensorName: String)]()
106+
for i in 0 ..< sl.numExperts {
107+
let expertName = unstackedBaseName.replacingOccurrences(of: ".experts.", with: ".experts.\(i).")
108+
let fullKey = prefix + expertName
109+
if let file = ExpertStreamerManager.shared?.getFile(for: fullKey),
110+
let dir = ExpertStreamingConfig.shared.modelDirectory {
111+
map[i] = (dir.appendingPathComponent(file).path, fullKey)
112+
}
113+
}
114+
sl.unstackedSSDMap = map
115+
break
116+
}
117+
}
118+
if foundUnstacked { continue }
119+
}
120+
121+
// Normal stacked format
95122
let originalKey = knownPrefixes.lazy
96123
.map { $0 + bareName }
97124
.first { ExpertStreamerManager.shared?.getFile(for: $0) != nil }
98-
?? bareName // fallback: use bare name (works when model has no VLM wrapper)
99-
qsl.tensorName = originalKey
125+
?? bareName // fallback: use bare name
126+
sl.tensorName = originalKey
100127
}
101128
}
102129
}

0 commit comments

Comments
 (0)