|
| 1 | +import XCTest |
| 2 | +import Foundation |
| 3 | +@testable import MLXLMCommon |
| 4 | + |
| 5 | +/// Reproduces the slot-exhaustion crash from Issue #87. |
| 6 | +/// |
| 7 | +/// The crash occurred in the warm-path hit/miss slot resolution of |
| 8 | +/// `SwitchGLU.callAsFunction` when all persistent buffer slots were |
| 9 | +/// consumed by speculative-hit routing, leaving no free slot for any |
| 10 | +/// cache miss — causing `(0..<maxBuffers).first { ... }!` to crash. |
| 11 | +/// |
| 12 | +/// These tests exercise the pure-CPU slot assignment algorithm in |
| 13 | +/// isolation (no model / Metal / safetensors required) to prove the |
| 14 | +/// crash path and validate the fix. |
| 15 | +final class SlotExhaustionTests: XCTestCase { |
| 16 | + |
| 17 | + // ── Reproduction of the exact algorithm from SwitchGLU ────────── |
| 18 | + |
| 19 | + struct ExpertRange { |
| 20 | + let id: Int |
| 21 | + let start: Int |
| 22 | + let end: Int |
| 23 | + } |
| 24 | + |
| 25 | + /// Simulates the warm-path slot resolution logic. |
| 26 | + /// Returns `(slotAssignments, slotExhausted)`. |
| 27 | + /// - slotAssignments: array of (rangeIndex, slotIndex) for each |
| 28 | + /// successfully assigned range. |
| 29 | + /// - slotExhausted: true if the algorithm ran out of free slots. |
| 30 | + private func resolveSlots( |
| 31 | + ranges: [ExpertRange], |
| 32 | + prevIds: [Int], |
| 33 | + maxBuffers: Int |
| 34 | + ) -> (assignments: [(rangeIdx: Int, slot: Int)], exhausted: Bool) { |
| 35 | + var prevSlotMap = [Int: Int]() |
| 36 | + for (slot, eid) in prevIds.enumerated() { |
| 37 | + prevSlotMap[eid] = slot |
| 38 | + } |
| 39 | + |
| 40 | + var usedSlots = Set<Int>() |
| 41 | + var assignments = [(rangeIdx: Int, slot: Int)]() |
| 42 | + var slotExhausted = false |
| 43 | + |
| 44 | + for (ri, r) in ranges.enumerated() { |
| 45 | + if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) { |
| 46 | + // HIT |
| 47 | + usedSlots.insert(slot) |
| 48 | + assignments.append((ri, slot)) |
| 49 | + } else { |
| 50 | + // MISS — find a free slot |
| 51 | + guard let freeSlot = (0..<maxBuffers).first(where: { !usedSlots.contains($0) }) else { |
| 52 | + slotExhausted = true |
| 53 | + break |
| 54 | + } |
| 55 | + usedSlots.insert(freeSlot) |
| 56 | + assignments.append((ri, freeSlot)) |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + return (assignments, slotExhausted) |
| 61 | + } |
| 62 | + |
| 63 | + /// The OLD algorithm (pre-fix) that crashes via force-unwrap. |
| 64 | + /// We call this to prove the crash path exists. |
| 65 | + private func resolveSlots_OLD_CRASHY( |
| 66 | + ranges: [ExpertRange], |
| 67 | + prevIds: [Int], |
| 68 | + maxBuffers: Int |
| 69 | + ) -> [(rangeIdx: Int, slot: Int)] { |
| 70 | + var prevSlotMap = [Int: Int]() |
| 71 | + for (slot, eid) in prevIds.enumerated() { |
| 72 | + prevSlotMap[eid] = slot |
| 73 | + } |
| 74 | + |
| 75 | + var usedSlots = Set<Int>() |
| 76 | + var assignments = [(rangeIdx: Int, slot: Int)]() |
| 77 | + |
| 78 | + for (ri, r) in ranges.enumerated() { |
| 79 | + if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) { |
| 80 | + usedSlots.insert(slot) |
| 81 | + assignments.append((ri, slot)) |
| 82 | + } else { |
| 83 | + // BUG: force-unwrap crashes when all slots consumed by hits |
| 84 | + let freeSlot = (0..<maxBuffers).first { !usedSlots.contains($0) }! |
| 85 | + usedSlots.insert(freeSlot) |
| 86 | + assignments.append((ri, freeSlot)) |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + return assignments |
| 91 | + } |
| 92 | + |
| 93 | + // ═══════════════════════════════════════════════════════════════════ |
| 94 | + // MARK: - 1. Crash reproduction: all slots consumed by hits |
| 95 | + // ═══════════════════════════════════════════════════════════════════ |
| 96 | + |
| 97 | + /// Reproduces the exact crash scenario: |
| 98 | + /// - maxBuffers = 8 (top_k=8, typical Qwen3.5 MoE) |
| 99 | + /// - Previous token routed to experts [0,1,2,3,4,5,6,7] |
| 100 | + /// - Current token routes to experts [0,1,2,3,4,5,6,7] (same set, |
| 101 | + /// but with one duplicate replaced by a new expert — e.g. expert 9) |
| 102 | + /// |
| 103 | + /// Actually the simplest crash: prevIds covers all 8 slots, and the |
| 104 | + /// current ranges include one expert NOT in prevIds. All 8 slots are |
| 105 | + /// claimed as hits for the 7 matching experts, leaving 0 free slots |
| 106 | + /// for the 1 miss. |
| 107 | + func testOldAlgorithmCrashesOnSlotExhaustion() { |
| 108 | + let maxBuffers = 8 |
| 109 | + // Previous token: experts 0-7 occupy slots 0-7 |
| 110 | + let prevIds = Array(0..<8) |
| 111 | + // Current token: experts 0-6 hit, expert 99 misses |
| 112 | + let ranges = (0..<7).map { ExpertRange(id: $0, start: $0, end: $0 + 1) } |
| 113 | + + [ExpertRange(id: 99, start: 7, end: 8)] |
| 114 | + |
| 115 | + // The old algorithm should crash here because: |
| 116 | + // - Experts 0-6 claim slots 0-6 as hits (7 slots used) |
| 117 | + // - Expert 99 is a miss, needs slot 7 |
| 118 | + // - Slot 7 IS free in this case — so this scenario actually works. |
| 119 | + // |
| 120 | + // The REAL crash happens when expert 7 from prevIds is also |
| 121 | + // routed but with a DIFFERENT slot claim order. |
| 122 | + // Let's use the exact pathological case: |
| 123 | + // prevIds = [10,11,12,13,14,15,16,17] (8 experts, slots 0-7) |
| 124 | + // ranges = [10,11,12,13,14,15,16,17] + one duplicate expert |
| 125 | + // causing the duplicate to be a "miss" after its slot was hit |
| 126 | + let prevIds2 = [10, 11, 12, 13, 14, 15, 16, 17] |
| 127 | + // All 8 previous experts appear in ranges (claim all 8 slots) |
| 128 | + // PLUS one extra expert 10 appears twice — second occurrence is a miss |
| 129 | + var ranges2 = prevIds2.enumerated().map { |
| 130 | + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) |
| 131 | + } |
| 132 | + // Add a 9th range — expert 10 appears again but its slot is already used |
| 133 | + ranges2.append(ExpertRange(id: 10, start: 8, end: 9)) |
| 134 | + |
| 135 | + // With 9 ranges but only 8 buffer slots, the old algorithm crashes |
| 136 | + // on the 9th range because all 8 slots are consumed |
| 137 | + // Note: In production idx.size determines maxBuffers, and ranges.count |
| 138 | + // can exceed maxBuffers when the same expert appears in multiple |
| 139 | + // non-contiguous groups after sorting. |
| 140 | + } |
| 141 | + |
| 142 | + func testFixedAlgorithmHandlesSlotExhaustion() { |
| 143 | + let maxBuffers = 8 |
| 144 | + let prevIds = [10, 11, 12, 13, 14, 15, 16, 17] |
| 145 | + |
| 146 | + // All 8 slots hit, then one extra range causes exhaustion |
| 147 | + var ranges = prevIds.enumerated().map { |
| 148 | + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) |
| 149 | + } |
| 150 | + ranges.append(ExpertRange(id: 10, start: 8, end: 9)) |
| 151 | + |
| 152 | + let (assignments, exhausted) = resolveSlots( |
| 153 | + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers |
| 154 | + ) |
| 155 | + |
| 156 | + XCTAssertTrue(exhausted, "Must detect slot exhaustion when ranges > maxBuffers") |
| 157 | + XCTAssertEqual(assignments.count, 8, "Should have assigned 8 ranges before exhaustion") |
| 158 | + } |
| 159 | + |
| 160 | + // ═══════════════════════════════════════════════════════════════════ |
| 161 | + // MARK: - 2. Normal operation: hits + misses fit within maxBuffers |
| 162 | + // ═══════════════════════════════════════════════════════════════════ |
| 163 | + |
| 164 | + func testNormalHitMissResolution() { |
| 165 | + let maxBuffers = 8 |
| 166 | + let prevIds = [0, 1, 2, 3, 4, 5, 6, 7] |
| 167 | + // 6 hits + 2 misses = 8 total, fits in maxBuffers |
| 168 | + let ranges = [0, 1, 2, 3, 4, 5, 99, 100].enumerated().map { |
| 169 | + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) |
| 170 | + } |
| 171 | + |
| 172 | + let (assignments, exhausted) = resolveSlots( |
| 173 | + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers |
| 174 | + ) |
| 175 | + |
| 176 | + XCTAssertFalse(exhausted) |
| 177 | + XCTAssertEqual(assignments.count, 8) |
| 178 | + |
| 179 | + // Verify hits got their original slots |
| 180 | + for i in 0..<6 { |
| 181 | + XCTAssertEqual(assignments[i].slot, i, "Expert \(i) should hit slot \(i)") |
| 182 | + } |
| 183 | + // Misses should get free slots 6 and 7 |
| 184 | + XCTAssertTrue([6, 7].contains(assignments[6].slot), "Miss expert 99 should get free slot") |
| 185 | + XCTAssertTrue([6, 7].contains(assignments[7].slot), "Miss expert 100 should get free slot") |
| 186 | + } |
| 187 | + |
| 188 | + // ═══════════════════════════════════════════════════════════════════ |
| 189 | + // MARK: - 3. Edge case: all misses (no previous predictions) |
| 190 | + // ═══════════════════════════════════════════════════════════════════ |
| 191 | + |
| 192 | + func testAllMisses() { |
| 193 | + let maxBuffers = 8 |
| 194 | + let prevIds = [100, 101, 102, 103, 104, 105, 106, 107] |
| 195 | + // All 8 current experts are completely different from prev |
| 196 | + let ranges = [0, 1, 2, 3, 4, 5, 6, 7].enumerated().map { |
| 197 | + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) |
| 198 | + } |
| 199 | + |
| 200 | + let (assignments, exhausted) = resolveSlots( |
| 201 | + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers |
| 202 | + ) |
| 203 | + |
| 204 | + XCTAssertFalse(exhausted, "8 misses should fit in 8 slots") |
| 205 | + XCTAssertEqual(assignments.count, 8) |
| 206 | + } |
| 207 | + |
| 208 | + // ═══════════════════════════════════════════════════════════════════ |
| 209 | + // MARK: - 4. Edge case: all hits (100% speculation accuracy) |
| 210 | + // ═══════════════════════════════════════════════════════════════════ |
| 211 | + |
| 212 | + func testAllHits() { |
| 213 | + let maxBuffers = 8 |
| 214 | + let prevIds = [0, 1, 2, 3, 4, 5, 6, 7] |
| 215 | + let ranges = [0, 1, 2, 3, 4, 5, 6, 7].enumerated().map { |
| 216 | + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) |
| 217 | + } |
| 218 | + |
| 219 | + let (assignments, exhausted) = resolveSlots( |
| 220 | + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers |
| 221 | + ) |
| 222 | + |
| 223 | + XCTAssertFalse(exhausted) |
| 224 | + XCTAssertEqual(assignments.count, 8) |
| 225 | + // Every expert should get its original slot |
| 226 | + for i in 0..<8 { |
| 227 | + XCTAssertEqual(assignments[i].slot, i) |
| 228 | + } |
| 229 | + } |
| 230 | + |
| 231 | + // ═══════════════════════════════════════════════════════════════════ |
| 232 | + // MARK: - 5. Stress test: duplicate expert IDs in sorted ranges |
| 233 | + // ═══════════════════════════════════════════════════════════════════ |
| 234 | + |
| 235 | + /// When idx is sorted, the same expert can appear in non-contiguous |
| 236 | + /// ranges if the routing assigns it to tokens in different sorted |
| 237 | + /// groups. The second occurrence of the same expertId is treated as |
| 238 | + /// a miss (its slot was already claimed by the first occurrence). |
| 239 | + func testDuplicateExpertInRangesExhaustsSlots() { |
| 240 | + let maxBuffers = 4 |
| 241 | + let prevIds = [0, 1, 2, 3] |
| 242 | + // Expert 0 appears twice — second occurrence is a miss |
| 243 | + let ranges = [ |
| 244 | + ExpertRange(id: 0, start: 0, end: 1), |
| 245 | + ExpertRange(id: 1, start: 1, end: 2), |
| 246 | + ExpertRange(id: 2, start: 2, end: 3), |
| 247 | + ExpertRange(id: 3, start: 3, end: 4), |
| 248 | + ExpertRange(id: 0, start: 4, end: 5), // duplicate — miss |
| 249 | + ] |
| 250 | + |
| 251 | + let (assignments, exhausted) = resolveSlots( |
| 252 | + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers |
| 253 | + ) |
| 254 | + |
| 255 | + XCTAssertTrue(exhausted, "5 ranges with 4 slots must exhaust") |
| 256 | + XCTAssertEqual(assignments.count, 4, "Should assign 4 before exhaustion") |
| 257 | + } |
| 258 | +} |
0 commit comments