Skip to content

Commit 2c2cd9e

Browse files
authored
Merge pull request #38 from SharpAI/fix/ssd-streaming-crash-recovery
Recover from SSD streaming errors without crashing
2 parents 2b3f92d + 38d7ff2 commit 2c2cd9e

3 files changed

Lines changed: 155 additions & 12 deletions

File tree

Libraries/MLXLMCommon/ConcurrentError.swift

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,87 @@
11
import Foundation
22
import MLX
33

4+
/// Error reported when SSD expert streaming hits a corrupted, truncated,
/// or incomplete safetensors file during pread I/O.
///
/// The value wraps the lower-level error and supplies a user-facing message
/// through `LocalizedError`.
public struct SSDStreamingError: Error, LocalizedError {
    /// The error that this streaming error wraps.
    public let underlyingError: Error

    /// Creates a streaming error wrapping `underlyingError`.
    ///
    /// - Parameter underlyingError: The error to wrap.
    public init(underlyingError: Error) {
        self.underlyingError = underlyingError
    }

    /// The underlying error's localized description, followed by a hint that
    /// the model file may need to be downloaded again.
    public var errorDescription: String? {
        let cause = underlyingError.localizedDescription
        return "MLX SSD Streaming Error: \(cause). The model safetensors file may be corrupted, truncated, or incomplete. Try re-downloading the model."
    }
}
17+
18+
/// Namespace for the constant used to look up the active latch in a thread dictionary.
private enum SSDStreamingErrorLatchContext {
    /// Thread-dictionary key under which the active `SSDStreamingErrorLatch` is stored.
    static let threadDictionaryKey: String = "MLXLMCommon.SSDStreamingErrorLatch.active"
}
21+
22+
/// Collects SSD streaming errors raised inside non-throwing `callAsFunction`
/// paths so a caller can surface them afterwards.
///
/// A generation installs its own latch with `withActive` around model
/// execution, so concurrent sessions do not cross-contaminate each other.
public final class SSDStreamingErrorLatch: @unchecked Sendable {
    /// Process-wide latch instance.
    public static let shared = SSDStreamingErrorLatch()

    private let lock = NSLock()
    private var _error: Error?

    public init() {}

    /// Runs `body` with `latch` registered as the current thread's active latch.
    ///
    /// Whatever was registered before the call (possibly nothing) is put back
    /// once `body` returns or throws.
    ///
    /// - Parameters:
    ///   - latch: The latch to expose through `active` while `body` runs.
    ///   - body: The work to perform.
    /// - Returns: The value produced by `body`.
    /// - Throws: Rethrows any error thrown by `body`.
    package static func withActive<T>(_ latch: SSDStreamingErrorLatch, _ body: () throws -> T) rethrows -> T {
        let slot = SSDStreamingErrorLatchContext.threadDictionaryKey as NSString
        let storage = Thread.current.threadDictionary
        let saved = storage[slot]
        storage[slot] = latch
        defer {
            // Restore the outer registration, or clear the slot if there was none.
            switch saved {
            case .some(let outer):
                storage[slot] = outer
            case .none:
                storage.removeObject(forKey: slot)
            }
        }
        return try body()
    }

    /// The latch registered for the current thread by `withActive`, if any.
    package static var active: SSDStreamingErrorLatch? {
        let slot = SSDStreamingErrorLatchContext.threadDictionaryKey as NSString
        let registered = Thread.current.threadDictionary[slot]
        return registered as? SSDStreamingErrorLatch
    }

    /// Stores `error` unless an error is already stored (the first one wins).
    public func set(_ error: Error) {
        lock.lock()
        defer { lock.unlock() }
        guard _error == nil else { return }
        _error = error
    }

    /// Removes and returns the stored error, leaving the latch empty.
    ///
    /// - Returns: The stored error, or `nil` if none was recorded.
    public func consume() -> Error? {
        lock.lock()
        defer { lock.unlock() }
        let pending = _error
        _error = nil
        return pending
    }

    /// Throws the stored error, if there is one, and clears it.
    ///
    /// - Throws: The error previously passed to `set(_:)`.
    public func throwIfSet() throws {
        guard let pending = consume() else { return }
        throw pending
    }
}
76+
477
package final class ThreadSafeError: @unchecked Sendable {
578
package let lock = NSLock()
679
package var error: Swift.Error?
80+
private let latch: SSDStreamingErrorLatch?
781

8-
package init() {}
82+
package init(latch: SSDStreamingErrorLatch? = SSDStreamingErrorLatch.active) {
83+
self.latch = latch
84+
}
985

1086
package func catchError(_ block: () throws -> Void) {
1187
do {
@@ -19,9 +95,20 @@ package final class ThreadSafeError: @unchecked Sendable {
1995
}
2096
}
2197

22-
package func check() {
98+
/// Check if any error was recorded during concurrent I/O.
99+
///
100+
/// Instead of calling `fatalError` (which crashes the entire app), this
101+
/// records the error in this instance's latch (if any) and in `SSDStreamingErrorLatch.shared` so the generation
102+
/// loop can detect it after the current token and surface it gracefully
103+
/// in the UI (e.g., prompting a re-download).
104+
@discardableResult
105+
package func check() -> SSDStreamingError? {
23106
if let error = error {
24-
fatalError("MLX SSD Streaming Error: \(error.localizedDescription). (The model safetensors file may be corrupted, truncated, or incomplete).")
107+
let streamingError = SSDStreamingError(underlyingError: error)
108+
latch?.set(streamingError)
109+
SSDStreamingErrorLatch.shared.set(streamingError)
110+
return streamingError
25111
}
112+
return nil
26113
}
27114
}

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int
502502
var maxTokens: Int? { get }
503503
var tokenCount: Int { get }
504504
var promptPrefillTime: TimeInterval { get }
505+
var streamingError: SSDStreamingError? { get }
505506
}
506507

507508
/// Generator of tokens.
@@ -546,6 +547,8 @@ public struct TokenIterator: TokenIteratorProtocol {
546547

547548
// Internal metrics
548549
var promptPrefillTime: TimeInterval = 0.0
550+
var streamingError: SSDStreamingError?
551+
let ssdErrorLatch = SSDStreamingErrorLatch()
549552

550553
/// Initialize a `TokenIterator` with the given tokens. Note: this has been
551554
/// replaced with ``init(input:model:cache:parameters:)``.
@@ -646,16 +649,25 @@ public struct TokenIterator: TokenIteratorProtocol {
646649
mutating func prepare(input: LMInput, windowSize: Int? = nil) throws {
647650
processor?.prompt(input.text.tokens)
648651

649-
switch try model.prepare(input, cache: cache, windowSize: windowSize) {
652+
let preparation = try SSDStreamingErrorLatch.withActive(ssdErrorLatch) {
653+
try model.prepare(input, cache: cache, windowSize: windowSize)
654+
}
655+
656+
switch preparation {
650657
case .tokens(let tokens):
651658
y = tokens
652659

660+
try ssdErrorLatch.throwIfSet()
661+
653662
// evaluate the remainder of the prompt -- this primes the pump
654-
let token = step(previous: y)
663+
let token = try step(previous: y)
664+
655665
y = .init(tokens: token)
656666
asyncEval(y.tokens)
657667

658668
case .logits(let result):
669+
try ssdErrorLatch.throwIfSet()
670+
659671
y = .init(tokens: convertToToken(logits: result.logits))
660672
asyncEval(y.tokens)
661673

@@ -677,11 +689,14 @@ public struct TokenIterator: TokenIteratorProtocol {
677689
}
678690

679691
/// Evaluate the next token and return the new token (y), updating cache state
680-
mutating func step(previous: LMInput.Text) -> MLXArray {
681-
let result = model(
682-
previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state)
692+
mutating func step(previous: LMInput.Text) throws -> MLXArray {
693+
let result = SSDStreamingErrorLatch.withActive(ssdErrorLatch) {
694+
model(previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state)
695+
}
683696
self.state = result.state
684697

698+
try ssdErrorLatch.throwIfSet()
699+
685700
// Apply dynamic cache quantization after each step
686701
maybeQuantizeKVCache(
687702
cache: &cache,
@@ -694,6 +709,10 @@ public struct TokenIterator: TokenIteratorProtocol {
694709
}
695710

696711
mutating public func next() -> Int? {
712+
if streamingError != nil {
713+
return nil
714+
}
715+
697716
if let maxTokens, tokenCount >= maxTokens {
698717
return nil
699718
}
@@ -702,7 +721,17 @@ public struct TokenIterator: TokenIteratorProtocol {
702721
let previousY = y
703722

704723
// compute the next state and async eval the next token
705-
let token = step(previous: previousY)
724+
let token: MLXArray
725+
do {
726+
token = try step(previous: previousY)
727+
} catch let error as SSDStreamingError {
728+
streamingError = error
729+
return nil
730+
} catch {
731+
streamingError = SSDStreamingError(underlyingError: error)
732+
return nil
733+
}
734+
706735
y = .init(tokens: token)
707736
asyncEval(token)
708737

@@ -746,6 +775,7 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
746775
let draftModel: any LanguageModel
747776

748777
var mainState: LMOutput.State?
778+
public let streamingError: SSDStreamingError? = nil
749779
var mainCache: [KVCache]
750780
var draftCache: [KVCache]
751781
let quantizeKVCache: (inout [KVCache]) -> Void
@@ -1685,7 +1715,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
16851715
// Launch a Task to perform iteration asynchronously.
16861716
let task = Task {
16871717
let performIteration = {
1688-
let iterator = iterator.consume()
1718+
var iterator = iterator.consume()
16891719
var handler = handler.consume()
16901720

16911721
var start = Date.timeIntervalSinceReferenceDate
@@ -1698,7 +1728,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
16981728
tokenizer: tokenizer
16991729
)
17001730

1701-
for token in iterator {
1731+
while let token = iterator.next() {
17021732
// Check for cancellation on every loop iteration.
17031733
if Task.isCancelled {
17041734
stopReason = .cancelled
@@ -1732,7 +1762,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
17321762
}
17331763

17341764
if stopReason == nil {
1735-
if Task.isCancelled {
1765+
if Task.isCancelled || iterator.streamingError != nil {
17361766
stopReason = .cancelled
17371767
} else if let maxTokens = iterator.maxTokens, tokenCount >= maxTokens {
17381768
stopReason = .length

Tests/MLXLMTests/CorruptSafetensorsTests.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,32 @@ import Testing
55

66
@Suite
77
struct CorruptSafetensorsTests {
8+
/// Verifies that `ThreadSafeError.check()` publishes a caught error to the
/// latch that was active when the `ThreadSafeError` was created, and that the
/// latch then surfaces it as an `SSDStreamingError`.
@Test
func testThreadSafeErrorCheckPublishesToActiveLatch() throws {
    let latch = SSDStreamingErrorLatch()

    SSDStreamingErrorLatch.withActive(latch) {
        // Created inside the scope so its default `latch` argument resolves to `latch`.
        let errState = ThreadSafeError()
        errState.catchError {
            throw NSError(
                domain: "CorruptSafetensorsTests",
                code: 13,
                userInfo: [NSLocalizedDescriptionKey: "truncated shard"]
            )
        }

        #expect(errState.check() != nil)
    }

    // Capture whatever the latch throws so each outcome can be handled in one place.
    let surfaced: Error?
    do {
        try latch.throwIfSet()
        surfaced = nil
    } catch {
        surfaced = error
    }

    switch surfaced {
    case .none:
        Issue.record("Expected latch.throwIfSet() to surface an SSDStreamingError")
    case .some(let streaming as SSDStreamingError):
        #expect(streaming.localizedDescription.contains("truncated shard"))
    case .some(let other):
        Issue.record("Unexpected error type: \(other)")
    }
}
33+
834
@Test
935
func testDeadlock() throws {
1036
let tempDir = FileManager.default.temporaryDirectory

0 commit comments

Comments
 (0)