diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 7db6a4b5..cfffec1f 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -7,33 +7,130 @@ import AVFoundation import CoreML import Foundation import Hub +import os import Tokenizers -open class WhisperKit { - /// Models - public private(set) var modelVariant: ModelVariant = .tiny - public private(set) var modelState: ModelState = .unloaded { - didSet { - modelStateCallback?(oldValue, modelState) +/// Primary entry point for model lifecycle and transcription in WhisperKit. +/// +/// `WhisperKit` is intentionally an `open class` to allow subclass customization. +/// The type is annotated `@unchecked Sendable` so instances can be shared across +/// concurrency domains while preserving source compatibility for existing users. +/// +/// Internal mutable state is synchronized with an unfair lock. +/// +/// - Important: Synchronization only covers `WhisperKit`'s own stored state. +/// Custom dependencies you inject (for example `AudioProcessing` or `TextDecoding`) +/// are responsible for their own thread-safety. +open class WhisperKit: @unchecked Sendable { + // `WhisperKit` is open and mutable, so checked `Sendable` is not available. + // Guarding state with `stateLock` keeps internal shared mutation synchronized. + private let stateLock = OSAllocatedUnfairLock() + + private var _modelVariant: ModelVariant = .tiny + /// The model variant currently associated with this instance. + /// + /// The value is inferred during tokenizer loading based on model dimensions. + public private(set) var modelVariant: ModelVariant { + get { withStateLock { _modelVariant } } + set { withStateLock { _modelVariant = newValue } } + } + + private var _modelState: ModelState = .unloaded + /// Current lifecycle state of the model pipeline. + /// + /// Updating this value triggers ``modelStateCallback``. + public private(set) var modelState: ModelState { + get { withStateLock { _modelState } } + set { + let (oldState, newState, callback): (ModelState, ModelState, ModelStateCallback?) = withStateLock { + let oldState = _modelState + _modelState = newValue + return (oldState, newValue, _modelStateCallback) + } + callback?(oldState, newState) } } - public var modelCompute: ModelComputeOptions - public var audioInputConfig: AudioInputConfig + private var _modelCompute: ModelComputeOptions + /// Compute-unit preferences for the model components. + public var modelCompute: ModelComputeOptions { + get { withStateLock { _modelCompute } } + set { withStateLock { _modelCompute = newValue } } + } + + private var _audioInputConfig: AudioInputConfig + /// Audio input handling configuration used for loading and preprocessing. + public var audioInputConfig: AudioInputConfig { + get { withStateLock { _audioInputConfig } } + set { withStateLock { _audioInputConfig = newValue } } + } + + private var _tokenizer: WhisperTokenizer? + /// The tokenizer used by transcription and decoding. + /// + /// Setting this property also updates ``textDecoder.tokenizer`` to keep both + /// components in sync. public var tokenizer: WhisperTokenizer? { - didSet { - // Always sync the tokenizer to the text decoder when set - textDecoder.tokenizer = tokenizer + get { withStateLock { _tokenizer } } + set { + withStateLock { + _tokenizer = newValue + // Always sync the tokenizer to the text decoder when set. + _textDecoder.tokenizer = newValue + } } } /// Protocols - public var audioProcessor: any AudioProcessing - public var featureExtractor: any FeatureExtracting - public var audioEncoder: any AudioEncoding - public var textDecoder: any TextDecoding - public var segmentSeeker: any SegmentSeeking - public var voiceActivityDetector: VoiceActivityDetector? + private var _audioProcessor: any AudioProcessing + /// Audio processing component used by the transcription pipeline. + public var audioProcessor: any AudioProcessing { + get { withStateLock { _audioProcessor } } + set { withStateLock { _audioProcessor = newValue } } + } + + private var _featureExtractor: any FeatureExtracting + /// Feature extraction component used by the transcription pipeline. + public var featureExtractor: any FeatureExtracting { + get { withStateLock { _featureExtractor } } + set { withStateLock { _featureExtractor = newValue } } + } + + private var _audioEncoder: any AudioEncoding + /// Audio encoder component used by the transcription pipeline. + public var audioEncoder: any AudioEncoding { + get { withStateLock { _audioEncoder } } + set { withStateLock { _audioEncoder = newValue } } + } + + private var _textDecoder: any TextDecoding + /// Decoder implementation used for token prediction and language detection. + /// + /// When replaced, the current ``tokenizer`` is propagated to the new decoder. + public var textDecoder: any TextDecoding { + get { withStateLock { _textDecoder } } + set { + withStateLock { + _textDecoder = newValue + // Keep decoder state consistent if decoder is swapped post-init. + _textDecoder.tokenizer = _tokenizer + } + } + } + + private var _segmentSeeker: any SegmentSeeking + /// Segment seeking implementation used during windowed decoding. + public var segmentSeeker: any SegmentSeeking { + get { withStateLock { _segmentSeeker } } + set { withStateLock { _segmentSeeker = newValue } } + } + + private var _voiceActivityDetector: VoiceActivityDetector? + /// Optional voice activity detector used for VAD-based chunking workflows. + public var voiceActivityDetector: VoiceActivityDetector? { + get { withStateLock { _voiceActivityDetector } } + set { withStateLock { _voiceActivityDetector = newValue } } + } /// Shapes public static let sampleRate: Int = 16000 @@ -41,35 +138,87 @@ open class WhisperKit { public static let secondsPerTimeToken = Float(0.02) /// Progress - public private(set) var currentTimings: TranscriptionTimings - public private(set) var progress = Progress() + private var _currentTimings: TranscriptionTimings + /// Aggregated timings for the most recent model/transcription work. + public private(set) var currentTimings: TranscriptionTimings { + get { withStateLock { _currentTimings } } + set { withStateLock { _currentTimings = newValue } } + } + + private var _progress = Progress() + /// Progress for the active transcription workflow. + public private(set) var progress: Progress { + get { withStateLock { _progress } } + set { withStateLock { _progress = newValue } } + } /// Configuration - public var modelFolder: URL? - public var tokenizerFolder: URL? - public private(set) var useBackgroundDownloadSession: Bool + private var _modelFolder: URL? + /// Local model folder used for loading models when configured. + public var modelFolder: URL? { + get { withStateLock { _modelFolder } } + set { withStateLock { _modelFolder = newValue } } + } - /// Callbacks - public var segmentDiscoveryCallback: SegmentDiscoveryCallback? - public var modelStateCallback: ModelStateCallback? - public var transcriptionStateCallback: TranscriptionStateCallback? + private var _tokenizerFolder: URL? + /// Preferred local tokenizer folder used for tokenizer resolution. + public var tokenizerFolder: URL? { + get { withStateLock { _tokenizerFolder } } + set { withStateLock { _tokenizerFolder = newValue } } + } + + private var _useBackgroundDownloadSession: Bool + /// Whether background URLSession downloads are enabled for remote assets. + public private(set) var useBackgroundDownloadSession: Bool { + get { withStateLock { _useBackgroundDownloadSession } } + set { withStateLock { _useBackgroundDownloadSession = newValue } } + } + + private var _segmentDiscoveryCallback: SegmentDiscoveryCallback? + /// Callback invoked when new segments are discovered during decoding. + public var segmentDiscoveryCallback: SegmentDiscoveryCallback? { + get { withStateLock { _segmentDiscoveryCallback } } + set { withStateLock { _segmentDiscoveryCallback = newValue } } + } + + private var _modelStateCallback: ModelStateCallback? + /// Callback invoked whenever ``modelState`` changes. + /// + /// The callback receives the previous and new state in transition order. + public var modelStateCallback: ModelStateCallback? { + get { withStateLock { _modelStateCallback } } + set { withStateLock { _modelStateCallback = newValue } } + } + + private var _transcriptionStateCallback: TranscriptionStateCallback? + /// Callback invoked for high-level transcription lifecycle updates. + public var transcriptionStateCallback: TranscriptionStateCallback? { + get { withStateLock { _transcriptionStateCallback } } + set { withStateLock { _transcriptionStateCallback = newValue } } + } public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws { - modelCompute = config.computeOptions ?? ModelComputeOptions() - audioInputConfig = config.audioInputConfig ?? AudioInputConfig() - audioProcessor = config.audioProcessor ?? AudioProcessor() - featureExtractor = config.featureExtractor ?? FeatureExtractor() - audioEncoder = config.audioEncoder ?? AudioEncoder() - textDecoder = config.textDecoder ?? TextDecoder() + _modelCompute = config.computeOptions ?? ModelComputeOptions() + _audioInputConfig = config.audioInputConfig ?? AudioInputConfig() + _audioProcessor = config.audioProcessor ?? AudioProcessor() + _featureExtractor = config.featureExtractor ?? FeatureExtractor() + _audioEncoder = config.audioEncoder ?? AudioEncoder() + _textDecoder = config.textDecoder ?? TextDecoder() + _segmentSeeker = config.segmentSeeker ?? SegmentSeeker() + _voiceActivityDetector = config.voiceActivityDetector + _tokenizerFolder = config.tokenizerFolder ?? config.downloadBase + _useBackgroundDownloadSession = config.useBackgroundDownloadSession + _currentTimings = TranscriptionTimings() + _tokenizer = nil + _modelFolder = nil + _segmentDiscoveryCallback = nil + _modelStateCallback = nil + _transcriptionStateCallback = nil + if let logitsFilters = config.logitsFilters { - textDecoder.logitsFilters = logitsFilters + _textDecoder.logitsFilters = logitsFilters } - - segmentSeeker = config.segmentSeeker ?? SegmentSeeker() - voiceActivityDetector = config.voiceActivityDetector - tokenizerFolder = config.tokenizerFolder ?? config.downloadBase - useBackgroundDownloadSession = config.useBackgroundDownloadSession - currentTimings = TranscriptionTimings() + await Logging.updateLogLevel(config.verbose ? config.logLevel : .none) try await setupModels( @@ -94,6 +243,13 @@ open class WhisperKit { } } + @inline(__always) + private func withStateLock(_ body: () throws -> T) rethrows -> T { + try stateLock.withLockUnchecked { _ in + try body() + } + } + public convenience init( model: String? = nil, downloadBase: URL? = nil, diff --git a/Tests/WhisperKitTests/Mocks/WhisperTokenizerMock.swift b/Tests/WhisperKitTests/Mocks/WhisperTokenizerMock.swift new file mode 100644 index 00000000..9f3596f4 --- /dev/null +++ b/Tests/WhisperKitTests/Mocks/WhisperTokenizerMock.swift @@ -0,0 +1,46 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +@testable import WhisperKit + +final class WhisperTokenizerMock: WhisperTokenizer { + let specialTokens: SpecialTokens + let allLanguageTokens: Set + + init(specialTokenBegin: Int) { + self.specialTokens = SpecialTokens( + endToken: 1, + englishToken: 2, + noSpeechToken: 3, + noTimestampsToken: 4, + specialTokenBegin: specialTokenBegin, + startOfPreviousToken: 5, + startOfTranscriptToken: 6, + timeTokenBegin: 7, + transcribeToken: 8, + translateToken: 9, + whitespaceToken: 10 + ) + self.allLanguageTokens = [] + } + + func encode(text: String) -> [Int] { + [1, 2, 3] + } + + func decode(tokens: [Int]) -> String { + "mock text" + } + + func convertTokenToId(_ token: String) -> Int? { + nil + } + + func convertIdToToken(_ id: Int) -> String? { + "token_\(id)" + } + + func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) { + (["mock"], [[1, 2, 3]]) + } +} diff --git a/Tests/WhisperKitTests/Utils/LockedStore.swift b/Tests/WhisperKitTests/Utils/LockedStore.swift new file mode 100644 index 00000000..99961239 --- /dev/null +++ b/Tests/WhisperKitTests/Utils/LockedStore.swift @@ -0,0 +1,16 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import os + +final class LockedStore: @unchecked Sendable { + private let lock: OSAllocatedUnfairLock + + init(_ initialValue: Value) { + self.lock = OSAllocatedUnfairLock(initialState: initialValue) + } + + func withValue(_ body: @Sendable (inout Value) -> T) -> T { + lock.withLock(body) + } +} diff --git a/Tests/WhisperKitTests/WhisperKitTests.swift b/Tests/WhisperKitTests/WhisperKitTests.swift new file mode 100644 index 00000000..b1403784 --- /dev/null +++ b/Tests/WhisperKitTests/WhisperKitTests.swift @@ -0,0 +1,102 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Foundation +import os +@testable import WhisperKit +import XCTest + +final class WhisperKitTests: XCTestCase { + private func makeUnloadedWhisperKit() async throws -> WhisperKit { + let config = WhisperKitConfig( + verbose: false, + logLevel: .error, + load: false, + download: false + ) + return try await WhisperKit(config) + } + + func testTokenizerSetterSynchronizesTextDecoderTokenizer() async throws { + let whisperKit = try await makeUnloadedWhisperKit() + let tokenizer = WhisperTokenizerMock(specialTokenBegin: 1000) + + whisperKit.tokenizer = tokenizer + + let decoderTokenizer = try XCTUnwrap(whisperKit.textDecoder.tokenizer as? WhisperTokenizerMock) + XCTAssertTrue(decoderTokenizer === tokenizer, "Setting tokenizer should also update textDecoder.tokenizer") + } + + func testTextDecoderSetterPreservesExistingTokenizer() async throws { + let whisperKit = try await makeUnloadedWhisperKit() + let tokenizer = WhisperTokenizerMock(specialTokenBegin: 1000) + whisperKit.tokenizer = tokenizer + + let replacementDecoder = TextDecoder() + whisperKit.textDecoder = replacementDecoder + + let replacementTokenizer = try XCTUnwrap(replacementDecoder.tokenizer as? WhisperTokenizerMock) + XCTAssertTrue(replacementTokenizer === tokenizer, "Replacing textDecoder should preserve existing tokenizer") + } + + func testModelStateCallbackTransitionOrderDuringUnload() async throws { + let whisperKit = try await makeUnloadedWhisperKit() + let callbackExpectation = expectation(description: "Model state callback fired for unload transition") + callbackExpectation.expectedFulfillmentCount = 2 + + let transitions = LockedStore<[(old: String, new: String)]>([]) + whisperKit.modelStateCallback = { oldState, newState in + let oldDescription = oldState?.description ?? "nil" + let newDescription = newState.description + transitions.withValue { + $0.append((old: oldDescription, new: newDescription)) + } + callbackExpectation.fulfill() + } + + await whisperKit.unloadModels() + await fulfillment(of: [callbackExpectation], timeout: 1) + + let snapshot = transitions.withValue { $0 } + XCTAssertEqual(snapshot.count, 2) + XCTAssertEqual(snapshot[0].old, ModelState.unloaded.description) + XCTAssertEqual(snapshot[0].new, ModelState.unloading.description) + XCTAssertEqual(snapshot[1].old, ModelState.unloading.description) + XCTAssertEqual(snapshot[1].new, ModelState.unloaded.description) + } + + func testConcurrentWhisperKitPropertyAccess() async throws { + let whisperKit = try await makeUnloadedWhisperKit() + let iterations = 200 + + await withTaskGroup(of: Void.self) { taskGroup in + for index in 0..