Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Sources/WhisperKit/Core/Text/LogitsFilter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ open class SuppressTokensFilter: LogitsFiltering {
private let suppressTokenIndexes: [[NSNumber]]

/// Creates a filter that suppresses the given token ids.
///
/// Negative ids are discarded before anything is stored, so both
/// `suppressTokens` and the derived index list only ever hold
/// non-negative values.
///
/// - Parameter suppressTokens: Token ids to suppress. Negative entries
///   (e.g. a `-1` sentinel) are ignored.
public init(suppressTokens: [Int]) {
    // Defence in depth: refuse negative token ids here too. A negative
    // sentinel reaching the MLMultiArray.fill path yields a write at a
    // negative offset from dataPointer.
    // ref: https://github.com/argmaxinc/argmax-oss-swift/issues/392
    let nonNegativeTokens = suppressTokens.filter { $0 >= 0 }

    // Assign each stored property exactly once, from the filtered list.
    self.suppressTokens = nonNegativeTokens
    // Each id becomes a three-component index of the form [0, 0, id].
    self.suppressTokenIndexes = nonNegativeTokens.map { [0, 0, $0 as NSNumber] }
}

public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
Expand Down
8 changes: 7 additions & 1 deletion Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,13 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
}

if !options.supressTokens.isEmpty {
let filteredSupressTokens = options.supressTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin }
// Drop the OpenAI-reference sentinel "-1" (meaning "suppress all
// special tokens") and any other out-of-range values. Passing a
// negative index through SuppressTokensFilter produces a write before
// `dataPointer` (at `dataPointer - strides[2]` for -1), which on
// iOS 26 (where CoreML returns read-only MLMultiArray) triggers SIGBUS.
// ref: https://github.com/argmaxinc/argmax-oss-swift/issues/392
let filteredSupressTokens = options.supressTokens.filter { $0 >= 0 && $0 < tokenizer.specialTokens.specialTokenBegin }
allFilters.append(SuppressTokensFilter(suppressTokens: filteredSupressTokens))
}

Expand Down
8 changes: 8 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,14 @@ final class UnitTests: XCTestCase {
let logits3 = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
let result3 = tokensFilter3.filterLogits(logits3, withTokens: [])
XCTAssertEqual(result3.data(for: 2), [-FloatType.infinity, 0.2, -FloatType.infinity, 0.4, 0.5, -FloatType.infinity, -FloatType.infinity])

// Negative sentinels (e.g. -1 used as "suppress all special" in some callers)
// must not reach the fill path, or they become negative offsets from
// dataPointer. ref: #392
let tokensFilter4 = SuppressTokensFilter(suppressTokens: [-1, 0, -5, 3])
let logits4 = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
let result4 = tokensFilter4.filterLogits(logits4, withTokens: [])
XCTAssertEqual(result4.data(for: 2), [-FloatType.infinity, 0.2, 0.3, -FloatType.infinity, 0.5, 0.6, 0.7])
}

func testSuppressBlankFilter() throws {
Expand Down