Skip to content

Commit 7f7fdb0

Browse files
author
Aegis-AI
committed
fix(moe): prevent crash when persistent buffer slots are exhausted
When all buffer slots are claimed by speculative-hit routing (ranges.count == maxBuffers and all experts get different slot assignments), the free-slot search '.first { !usedSlots.contains($0) }' returns nil, and the trailing force-unwrap ('!') then crashes with _assertionFailure. Replace the force-unwrap with a guard that sets a slotExhausted flag and breaks out of the loop. When exhaustion is detected, the miss-only pread is skipped, the partially built usedGate/usedUp/usedDown arrays are cleared, and we fall through to the existing full-pread fallback path — same correctness, no crash. Fixes SharpAI/SwiftLM#87
1 parent 86e3f93 commit 7f7fdb0

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

Libraries/MLXLMCommon/SwitchLayers.swift

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
713713

714714
var usedSlots = Set<Int>()
715715
var missInfo = [(rangeIdx: Int, expertId: Int, bufferSlot: Int)]()
716+
var slotExhausted = false
716717

717718
for (ri, r) in ranges.enumerated() {
718719
if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) {
@@ -723,7 +724,12 @@ public class SwitchGLU: Module, @unchecked Sendable {
723724
usedSlots.insert(slot)
724725
} else {
725726
// MISS: find a free slot
726-
let freeSlot = (0..<maxBuffers).first { !usedSlots.contains($0) }!
727+
guard let freeSlot = (0..<maxBuffers).first(where: { !usedSlots.contains($0) }) else {
728+
// All buffer slots exhausted — fall through to
729+
// full-pread path below (Issue #87)
730+
slotExhausted = true
731+
break
732+
}
727733
usedGate.append(_persistentGate![freeSlot])
728734
usedUp.append(_persistentUp![freeSlot])
729735
usedDown.append(_persistentDown![freeSlot])
@@ -733,7 +739,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
733739
}
734740

735741
// Pread only misses (~30% of experts, ~6 reads at QD=6)
736-
if !missInfo.isEmpty {
742+
if !slotExhausted && !missInfo.isEmpty {
737743
let totalMissReads = missInfo.count * 3
738744
let errState = ThreadSafeError()
739745
DispatchQueue.concurrentPerform(iterations: totalMissReads) { [missInfo] i in
@@ -762,8 +768,13 @@ public class SwitchGLU: Module, @unchecked Sendable {
762768
}
763769
errState.check()
764770
}
765-
} else {
766-
// No predictions available — full pread fallback
771+
}
772+
773+
// Slot exhaustion or no predictions — full pread fallback
774+
if usedGate.count != ranges.count {
775+
usedGate.removeAll()
776+
usedUp.removeAll()
777+
usedDown.removeAll()
767778
for i in 0..<ranges.count {
768779
usedGate.append(_persistentGate![i])
769780
usedUp.append(_persistentUp![i])

0 commit comments

Comments
 (0)