Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion Libraries/MLXLMCommon/ConcurrentError.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,49 @@
import Foundation
import MLX

/// Error thrown when SSD expert streaming encounters a corrupted, truncated,
/// or incomplete safetensors file during pread I/O.
public struct SSDStreamingError: Error, LocalizedError {
public let underlyingError: Error

public var errorDescription: String? {
"MLX SSD Streaming Error: \(underlyingError.localizedDescription). The model safetensors file may be corrupted, truncated, or incomplete. Try re-downloading the model."
}
}

/// Global error latch for SSD streaming errors that occur inside non-throwing
/// `callAsFunction` paths. Set by `ThreadSafeError.check()`, cleared and
/// inspected by the generation loop after each token.
public final class SSDStreamingErrorLatch: @unchecked Sendable {
public static let shared = SSDStreamingErrorLatch()
private let lock = NSLock()
private var _error: Error?

Comment on lines +25 to +29
/// Record an error (first-wins semantics).
public func set(_ error: Error) {
lock.withLock {
if _error == nil { _error = error }
}
}

/// Consume and return the recorded error, resetting the latch.
/// Returns nil if no error was recorded.
public func consume() -> Error? {
lock.withLock {
let e = _error
_error = nil
return e
}
}

/// Throw the recorded error if one exists, then clear it.
public func throwIfSet() throws {
if let error = consume() {
throw error
}
}
}

package final class ThreadSafeError: @unchecked Sendable {
package let lock = NSLock()
package var error: Swift.Error?
Expand All @@ -19,9 +62,17 @@ package final class ThreadSafeError: @unchecked Sendable {
}
}

/// Check if any error was recorded during concurrent I/O.
///
/// Instead of calling `fatalError` (which crashes the entire app), this
/// posts the error to the global `SSDStreamingErrorLatch` so the generation
/// loop can detect it after the current token and surface it gracefully
/// in the UI (e.g., prompting a re-download).
package func check() {
if let error = error {
fatalError("MLX SSD Streaming Error: \(error.localizedDescription). (The model safetensors file may be corrupted, truncated, or incomplete).")
SSDStreamingErrorLatch.shared.set(
SSDStreamingError(underlyingError: error)
)
}
Comment on lines +98 to 111
}
}
21 changes: 21 additions & 0 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,24 @@ public struct TokenIterator: TokenIteratorProtocol {
case .tokens(let tokens):
y = tokens

// Check for SSD streaming errors that occurred during prefill.
// The MoE expert pread path uses a non-throwing callAsFunction,
// so errors are posted to the global latch instead.
try SSDStreamingErrorLatch.shared.throwIfSet()

// evaluate the remainder of the prompt -- this primes the pump
let token = step(previous: y)

// Check again after step() which also runs through MoE layers
try SSDStreamingErrorLatch.shared.throwIfSet()

y = .init(tokens: token)
asyncEval(y.tokens)

case .logits(let result):
// Check for SSD streaming errors during logits computation
try SSDStreamingErrorLatch.shared.throwIfSet()

y = .init(tokens: convertToToken(logits: result.logits))
asyncEval(y.tokens)

Expand Down Expand Up @@ -1705,6 +1717,15 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
break
}

// Check for SSD streaming errors (truncated/corrupted safetensors).
// These are set by ThreadSafeError.check() inside SwitchGLU's non-throwing
// callAsFunction path via the global error latch.
if let ssdError = SSDStreamingErrorLatch.shared.consume() {
print("[MLXLMCommon] SSD streaming error detected: \(ssdError.localizedDescription)")
stopReason = .cancelled
break
}

if promptTime == 0 {
let now = Date.timeIntervalSinceReferenceDate
promptTime = now - start
Expand Down
Loading