@@ -502,6 +502,7 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int
502502 var maxTokens : Int ? { get }
503503 var tokenCount : Int { get }
504504 var promptPrefillTime : TimeInterval { get }
505+ var streamingError : SSDStreamingError ? { get }
505506}
506507
507508/// Generator of tokens.
@@ -546,6 +547,8 @@ public struct TokenIterator: TokenIteratorProtocol {
546547
547548 // Internal metrics
548549 var promptPrefillTime : TimeInterval = 0.0
550+ var streamingError : SSDStreamingError ?
551+ let ssdErrorLatch = SSDStreamingErrorLatch ( )
549552
550553 /// Initialize a `TokenIterator` with the given tokens. Note: this has been
551554 /// replaced with ``init(input:model:cache:parameters:)``.
@@ -646,16 +649,25 @@ public struct TokenIterator: TokenIteratorProtocol {
646649 mutating func prepare( input: LMInput , windowSize: Int ? = nil ) throws {
647650 processor? . prompt ( input. text. tokens)
648651
649- switch try model. prepare ( input, cache: cache, windowSize: windowSize) {
652+ let preparation = try SSDStreamingErrorLatch . withActive ( ssdErrorLatch) {
653+ try model. prepare ( input, cache: cache, windowSize: windowSize)
654+ }
655+
656+ switch preparation {
650657 case . tokens( let tokens) :
651658 y = tokens
652659
660+ try ssdErrorLatch. throwIfSet ( )
661+
653662 // evaluate the remainder of the prompt -- this primes the pump
654- let token = step ( previous: y)
663+ let token = try step ( previous: y)
664+
655665 y = . init( tokens: token)
656666 asyncEval ( y. tokens)
657667
658668 case . logits( let result) :
669+ try ssdErrorLatch. throwIfSet ( )
670+
659671 y = . init( tokens: convertToToken ( logits: result. logits) )
660672 asyncEval ( y. tokens)
661673
@@ -677,11 +689,14 @@ public struct TokenIterator: TokenIteratorProtocol {
677689 }
678690
679691 /// Evaluate the next token and return the new token (y), updating cache state
680- mutating func step( previous: LMInput . Text ) -> MLXArray {
681- let result = model (
682- previous [ text: . newAxis] , cache: cache. isEmpty ? nil : cache, state: state)
692+ mutating func step( previous: LMInput . Text ) throws -> MLXArray {
693+ let result = SSDStreamingErrorLatch . withActive ( ssdErrorLatch) {
694+ model ( previous [ text: . newAxis] , cache: cache. isEmpty ? nil : cache, state: state)
695+ }
683696 self . state = result. state
684697
698+ try ssdErrorLatch. throwIfSet ( )
699+
685700 // Apply dynamic cache quantization after each step
686701 maybeQuantizeKVCache (
687702 cache: & cache,
@@ -694,6 +709,10 @@ public struct TokenIterator: TokenIteratorProtocol {
694709 }
695710
696711 mutating public func next( ) -> Int ? {
712+ if streamingError != nil {
713+ return nil
714+ }
715+
697716 if let maxTokens, tokenCount >= maxTokens {
698717 return nil
699718 }
@@ -702,7 +721,17 @@ public struct TokenIterator: TokenIteratorProtocol {
702721 let previousY = y
703722
704723 // compute the next state and async eval the next token
705- let token = step ( previous: previousY)
724+ let token : MLXArray
725+ do {
726+ token = try step ( previous: previousY)
727+ } catch let error as SSDStreamingError {
728+ streamingError = error
729+ return nil
730+ } catch {
731+ streamingError = SSDStreamingError ( underlyingError: error)
732+ return nil
733+ }
734+
706735 y = . init( tokens: token)
707736 asyncEval ( token)
708737
@@ -746,6 +775,7 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
746775 let draftModel : any LanguageModel
747776
748777 var mainState : LMOutput . State ?
778+ public let streamingError : SSDStreamingError ? = nil
749779 var mainCache : [ KVCache ]
750780 var draftCache : [ KVCache ]
751781 let quantizeKVCache : ( inout [ KVCache ] ) -> Void
@@ -1685,7 +1715,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
16851715 // Launch a Task to perform iteration asynchronously.
16861716 let task = Task {
16871717 let performIteration = {
1688- let iterator = iterator. consume ( )
1718+ var iterator = iterator. consume ( )
16891719 var handler = handler. consume ( )
16901720
16911721 var start = Date . timeIntervalSinceReferenceDate
@@ -1698,7 +1728,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
16981728 tokenizer: tokenizer
16991729 )
17001730
1701- for token in iterator {
1731+ while let token = iterator. next ( ) {
17021732 // Check for cancellation on every loop iteration.
17031733 if Task . isCancelled {
17041734 stopReason = . cancelled
@@ -1732,7 +1762,7 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
17321762 }
17331763
17341764 if stopReason == nil {
1735- if Task . isCancelled {
1765+ if Task . isCancelled || iterator . streamingError != nil {
17361766 stopReason = . cancelled
17371767 } else if let maxTokens = iterator. maxTokens, tokenCount >= maxTokens {
17381768 stopReason = . length
0 commit comments