Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 195 additions & 39 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,69 +7,218 @@ import AVFoundation
import CoreML
import Foundation
import Hub
import os
import Tokenizers

open class WhisperKit {
/// Models
public private(set) var modelVariant: ModelVariant = .tiny
public private(set) var modelState: ModelState = .unloaded {
didSet {
modelStateCallback?(oldValue, modelState)
/// Primary entry point for model lifecycle and transcription in WhisperKit.
///
/// `WhisperKit` is intentionally an `open class` to allow subclass customization.
/// The type is annotated `@unchecked Sendable` so instances can be shared across
/// concurrency domains while preserving source compatibility for existing users.
///
/// Internal mutable state is synchronized with an unfair lock.
///
/// - Important: Synchronization only covers `WhisperKit`'s own stored state.
/// Custom dependencies you inject (for example `AudioProcessing` or `TextDecoding`)
/// are responsible for their own thread-safety.
open class WhisperKit: @unchecked Sendable {
// `WhisperKit` is open and mutable, so checked `Sendable` is not available.
// Guarding state with `stateLock` keeps internal shared mutation synchronized.
private let stateLock = OSAllocatedUnfairLock()

private var _modelVariant: ModelVariant = .tiny
/// The model variant currently associated with this instance.
///
/// The value is inferred during tokenizer loading based on model dimensions.
public private(set) var modelVariant: ModelVariant {
    get {
        return withStateLock { _modelVariant }
    }
    set {
        withStateLock { _modelVariant = newValue }
    }
}

private var _modelState: ModelState = .unloaded
/// Current lifecycle state of the model pipeline.
///
/// Updating this value triggers ``modelStateCallback``.
public private(set) var modelState: ModelState {
    get {
        return withStateLock { _modelState }
    }
    set {
        // Snapshot the previous state and the callback while holding the
        // lock, but invoke the callback outside it so arbitrary user code
        // never runs under `stateLock`.
        let (previousState, callback) = withStateLock { () -> (ModelState, ModelStateCallback?) in
            let previousState = _modelState
            _modelState = newValue
            return (previousState, _modelStateCallback)
        }
        callback?(previousState, newValue)
    }
}

private var _modelCompute: ModelComputeOptions
/// Compute-unit preferences for the model components.
public var modelCompute: ModelComputeOptions {
    get {
        return withStateLock { _modelCompute }
    }
    set {
        withStateLock { _modelCompute = newValue }
    }
}

private var _audioInputConfig: AudioInputConfig
/// Audio input handling configuration used for loading and preprocessing.
public var audioInputConfig: AudioInputConfig {
    get {
        return withStateLock { _audioInputConfig }
    }
    set {
        withStateLock { _audioInputConfig = newValue }
    }
}

private var _tokenizer: WhisperTokenizer?
/// The tokenizer used by transcription and decoding.
///
/// Setting this property also updates the text decoder's tokenizer so both
/// components stay in sync.
public var tokenizer: WhisperTokenizer? {
    get {
        return withStateLock { _tokenizer }
    }
    set {
        withStateLock {
            _tokenizer = newValue
            // Keep the decoder's tokenizer in sync with ours.
            // NOTE(review): this assignment runs while `stateLock` is held and
            // the unfair lock is not reentrant — confirm that no custom
            // `TextDecoding` implementation calls back into this instance here.
            _textDecoder.tokenizer = newValue
        }
    }
}

// MARK: - Pipeline components

private var _audioProcessor: any AudioProcessing
/// Audio processing component used by the transcription pipeline.
public var audioProcessor: any AudioProcessing {
    get {
        return withStateLock { _audioProcessor }
    }
    set {
        withStateLock { _audioProcessor = newValue }
    }
}

private var _featureExtractor: any FeatureExtracting
/// Feature extraction component used by the transcription pipeline.
public var featureExtractor: any FeatureExtracting {
    get {
        return withStateLock { _featureExtractor }
    }
    set {
        withStateLock { _featureExtractor = newValue }
    }
}

private var _audioEncoder: any AudioEncoding
/// Audio encoder component used by the transcription pipeline.
public var audioEncoder: any AudioEncoding {
    get {
        return withStateLock { _audioEncoder }
    }
    set {
        withStateLock { _audioEncoder = newValue }
    }
}

private var _textDecoder: any TextDecoding
/// Decoder implementation used for token prediction and language detection.
///
/// When replaced, the current ``tokenizer`` is propagated to the new decoder.
public var textDecoder: any TextDecoding {
    get {
        return withStateLock { _textDecoder }
    }
    set {
        withStateLock {
            _textDecoder = newValue
            // A decoder swapped in post-init must see the current tokenizer.
            _textDecoder.tokenizer = _tokenizer
        }
    }
}

private var _segmentSeeker: any SegmentSeeking
/// Segment seeking implementation used during windowed decoding.
public var segmentSeeker: any SegmentSeeking {
    get {
        return withStateLock { _segmentSeeker }
    }
    set {
        withStateLock { _segmentSeeker = newValue }
    }
}

private var _voiceActivityDetector: VoiceActivityDetector?
/// Optional voice activity detector used for VAD-based chunking workflows.
public var voiceActivityDetector: VoiceActivityDetector? {
    get {
        return withStateLock { _voiceActivityDetector }
    }
    set {
        withStateLock { _voiceActivityDetector = newValue }
    }
}

/// Shapes
// Input audio sample rate in Hz (16 kHz).
public static let sampleRate: Int = 16000
// Feature-extraction hop length in samples (10 ms at 16 kHz).
public static let hopLength: Int = 160
// Seconds of audio represented by a single timestamp token.
public static let secondsPerTimeToken = Float(0.02)

/// Progress

private var _currentTimings: TranscriptionTimings
/// Aggregated timings for the most recent model/transcription work.
public private(set) var currentTimings: TranscriptionTimings {
    get {
        return withStateLock { _currentTimings }
    }
    set {
        withStateLock { _currentTimings = newValue }
    }
}

private var _progress = Progress()
/// Progress for the active transcription workflow.
public private(set) var progress: Progress {
    get {
        return withStateLock { _progress }
    }
    set {
        withStateLock { _progress = newValue }
    }
}

/// Configuration

private var _modelFolder: URL?
/// Local model folder used for loading models when configured.
public var modelFolder: URL? {
    get {
        return withStateLock { _modelFolder }
    }
    set {
        withStateLock { _modelFolder = newValue }
    }
}

private var _tokenizerFolder: URL?
/// Preferred local tokenizer folder used for tokenizer resolution.
public var tokenizerFolder: URL? {
    get {
        return withStateLock { _tokenizerFolder }
    }
    set {
        withStateLock { _tokenizerFolder = newValue }
    }
}

private var _useBackgroundDownloadSession: Bool
/// Whether background URLSession downloads are enabled for remote assets.
public private(set) var useBackgroundDownloadSession: Bool {
    get {
        return withStateLock { _useBackgroundDownloadSession }
    }
    set {
        withStateLock { _useBackgroundDownloadSession = newValue }
    }
}

// MARK: - Callbacks

private var _segmentDiscoveryCallback: SegmentDiscoveryCallback?
/// Callback invoked when new segments are discovered during decoding.
public var segmentDiscoveryCallback: SegmentDiscoveryCallback? {
    get {
        return withStateLock { _segmentDiscoveryCallback }
    }
    set {
        withStateLock { _segmentDiscoveryCallback = newValue }
    }
}

private var _modelStateCallback: ModelStateCallback?
/// Callback invoked whenever ``modelState`` changes.
///
/// The callback receives the previous and new state in transition order.
public var modelStateCallback: ModelStateCallback? {
    get {
        return withStateLock { _modelStateCallback }
    }
    set {
        withStateLock { _modelStateCallback = newValue }
    }
}

private var _transcriptionStateCallback: TranscriptionStateCallback?
/// Callback invoked for high-level transcription lifecycle updates.
public var transcriptionStateCallback: TranscriptionStateCallback? {
    get {
        return withStateLock { _transcriptionStateCallback }
    }
    set {
        withStateLock { _transcriptionStateCallback = newValue }
    }
}

public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
modelCompute = config.computeOptions ?? ModelComputeOptions()
audioInputConfig = config.audioInputConfig ?? AudioInputConfig()
audioProcessor = config.audioProcessor ?? AudioProcessor()
featureExtractor = config.featureExtractor ?? FeatureExtractor()
audioEncoder = config.audioEncoder ?? AudioEncoder()
textDecoder = config.textDecoder ?? TextDecoder()
_modelCompute = config.computeOptions ?? ModelComputeOptions()
_audioInputConfig = config.audioInputConfig ?? AudioInputConfig()
_audioProcessor = config.audioProcessor ?? AudioProcessor()
_featureExtractor = config.featureExtractor ?? FeatureExtractor()
_audioEncoder = config.audioEncoder ?? AudioEncoder()
_textDecoder = config.textDecoder ?? TextDecoder()
_segmentSeeker = config.segmentSeeker ?? SegmentSeeker()
_voiceActivityDetector = config.voiceActivityDetector
_tokenizerFolder = config.tokenizerFolder ?? config.downloadBase
_useBackgroundDownloadSession = config.useBackgroundDownloadSession
_currentTimings = TranscriptionTimings()
_tokenizer = nil
_modelFolder = nil
_segmentDiscoveryCallback = nil
_modelStateCallback = nil
_transcriptionStateCallback = nil

if let logitsFilters = config.logitsFilters {
textDecoder.logitsFilters = logitsFilters
_textDecoder.logitsFilters = logitsFilters
}

segmentSeeker = config.segmentSeeker ?? SegmentSeeker()
voiceActivityDetector = config.voiceActivityDetector
tokenizerFolder = config.tokenizerFolder ?? config.downloadBase
useBackgroundDownloadSession = config.useBackgroundDownloadSession
currentTimings = TranscriptionTimings()

await Logging.updateLogLevel(config.verbose ? config.logLevel : .none)

try await setupModels(
Expand All @@ -94,6 +243,13 @@ open class WhisperKit {
}
}

/// Runs `body` while holding `stateLock`, forwarding its result and any error.
///
/// Every read and write of the `_`-prefixed backing storage goes through this
/// helper, giving instance state a single synchronization point.
@inline(__always)
private func withStateLock<T>(_ body: () throws -> T) rethrows -> T {
    try stateLock.withLockUnchecked { _ in try body() }
}

public convenience init(
model: String? = nil,
downloadBase: URL? = nil,
Expand Down
46 changes: 46 additions & 0 deletions Tests/WhisperKitTests/Mocks/WhisperTokenizerMock.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

@testable import WhisperKit

/// Minimal `WhisperTokenizer` stand-in for unit tests.
///
/// Every method returns a fixed, deterministic value so tests can exercise
/// code paths that need a tokenizer without loading a real vocabulary.
final class WhisperTokenizerMock: WhisperTokenizer {
    let specialTokens: SpecialTokens
    let allLanguageTokens: Set<Int>

    /// Creates a mock whose special tokens use small fixed ids; only
    /// `specialTokenBegin` is configurable by the caller.
    init(specialTokenBegin: Int) {
        specialTokens = SpecialTokens(
            endToken: 1,
            englishToken: 2,
            noSpeechToken: 3,
            noTimestampsToken: 4,
            specialTokenBegin: specialTokenBegin,
            startOfPreviousToken: 5,
            startOfTranscriptToken: 6,
            timeTokenBegin: 7,
            transcribeToken: 8,
            translateToken: 9,
            whitespaceToken: 10
        )
        allLanguageTokens = Set()
    }

    /// Always yields the same three token ids, regardless of input text.
    func encode(text: String) -> [Int] {
        return [1, 2, 3]
    }

    /// Always yields the same placeholder transcript.
    func decode(tokens: [Int]) -> String {
        return "mock text"
    }

    /// The mock has no vocabulary, so no token string maps to an id.
    func convertTokenToId(_ token: String) -> Int? {
        return nil
    }

    /// Formats any id as a deterministic placeholder token string.
    func convertIdToToken(_ id: Int) -> String? {
        return "token_\(id)"
    }

    /// Returns one fixed word paired with the fixed token ids.
    func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
        return (words: ["mock"], wordTokens: [[1, 2, 3]])
    }
}
16 changes: 16 additions & 0 deletions Tests/WhisperKitTests/Utils/LockedStore.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import os

/// A small `@unchecked Sendable` box that serializes access to a value with
/// an `OSAllocatedUnfairLock`.
///
/// Useful in tests for mutating shared state from concurrent contexts without
/// data races. `@unchecked` is sound here because the only stored property is
/// the lock itself and all value access goes through `withValue(_:)`.
final class LockedStore<Value: Sendable>: @unchecked Sendable {
    private let lock: OSAllocatedUnfairLock<Value>

    /// Creates a store protecting `initialValue`.
    init(_ initialValue: Value) {
        self.lock = OSAllocatedUnfairLock(initialState: initialValue)
    }

    /// Runs `body` with exclusive, in-out access to the stored value and
    /// returns its result.
    ///
    /// Generalized to `rethrows` (matching `OSAllocatedUnfairLock.withLock`)
    /// so throwing closures can be used; non-throwing call sites compile and
    /// behave exactly as before.
    func withValue<T: Sendable>(_ body: @Sendable (inout Value) throws -> T) rethrows -> T {
        try lock.withLock(body)
    }
}
Loading