Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 169 additions & 47 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,112 @@ public extension AudioProcessing {
}
}

open class AudioProcessor: NSObject, AudioProcessing {
private var lastInputDevice: DeviceID?
public var audioEngine: AVAudioEngine?
public var audioSamples: ContiguousArray<Float> = []
public var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] = []
public var relativeEnergyWindow: Int = 20
/// `AudioProcessor` participates in concurrent contexts (audio tap callbacks,
/// stream termination handlers, and transcriber reads). State mutations are
/// serialized with `stateLock`, making this reference type safe to share.
open class AudioProcessor: NSObject, AudioProcessing, @unchecked Sendable {
private let stateLock = UnfairLock()
private var lastInputDeviceStorage: DeviceID?
private var audioEngineStorage: AVAudioEngine?
private var audioSamplesStorage: ContiguousArray<Float> = []
private var audioEnergyStorage: [(rel: Float, avg: Float, max: Float, min: Float)] = []
private var relativeEnergyWindowStorage: Int = 20
private var audioBufferCallbackStorage: (([Float]) -> Void)?
private var minBufferLengthStorage = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz
private var isInputSuppressedStorage: Bool = false

/// The `AVAudioEngine` backing live recording, if one is active.
/// Reads and writes are serialized through `stateLock`.
public var audioEngine: AVAudioEngine? {
    get { stateLock.withLock { audioEngineStorage } }
    set { stateLock.withLock { audioEngineStorage = newValue } }
}
public var audioSamples: ContiguousArray<Float> {
stateLock.withLock {
audioSamplesStorage
}
Comment thread
naykutguven marked this conversation as resolved.
}

/// Per-buffer energy tuples (relative, average, max, min) accumulated
/// during processing. Access is serialized through `stateLock`.
public var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] {
    get { stateLock.withLock { audioEnergyStorage } }
    set { stateLock.withLock { audioEnergyStorage = newValue } }
}
/// Number of recent energy entries considered when computing relative
/// energy. Access is serialized through `stateLock`.
public var relativeEnergyWindow: Int {
    get { stateLock.withLock { relativeEnergyWindowStorage } }
    set { stateLock.withLock { relativeEnergyWindowStorage = newValue } }
}
public var relativeEnergy: [Float] {
return self.audioEnergy.map { $0.rel }
let energySnapshot = stateLock.withLock {
audioEnergyStorage
}
return energySnapshot.map { $0.rel }
}

public var audioBufferCallback: (([Float]) -> Void)?
public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz
public private(set) var isInputSuppressed = false
/// Whether incoming input buffers are currently being suppressed.
/// Externally read-only; mutate via `setInputSuppressed(_:)`.
/// Access is serialized through `stateLock`.
public private(set) var isInputSuppressed: Bool {
    get { stateLock.withLock { isInputSuppressedStorage } }
    set { stateLock.withLock { isInputSuppressedStorage = newValue } }
}

/// Enables or disables input suppression. While suppressed, input
/// buffers are replaced with silence so timing stays intact.
public func setInputSuppressed(_ suppressed: Bool) {
    isInputSuppressed = suppressed
}


/// Optional callback invoked with each newly processed sample buffer.
/// Stored and read under `stateLock`.
public var audioBufferCallback: (([Float]) -> Void)? {
    get { stateLock.withLock { audioBufferCallbackStorage } }
    set { stateLock.withLock { audioBufferCallbackStorage = newValue } }
}
public var minBufferLength: Int {
get {
stateLock.withLock {
minBufferLengthStorage
}
}
set {
stateLock.withLock {
minBufferLengthStorage = newValue
Comment thread
naykutguven marked this conversation as resolved.
}
}
}

/// Pads or trims `audioArray` to `frameLength` samples starting at
/// `startIndex`. Forwards to the static `padOrTrimAudio` helper with
/// `saveSegment: false`, so no segment is persisted.
open func padOrTrim(fromArray audioArray: [Float], startAt startIndex: Int, toLength frameLength: Int) -> (any AudioProcessorOutputType)? {
    AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: startIndex, toLength: frameLength, saveSegment: false)
}
Expand Down Expand Up @@ -905,23 +992,35 @@ public extension AudioProcessor {
/// We have a new buffer, process and store it.
/// NOTE: Assumes audio is 16khz mono
func processBuffer(_ buffer: [Float]) {
audioSamples.append(contentsOf: buffer)

// Find the lowest average energy of the last 20 buffers ~2 seconds
let minAvgEnergy = self.audioEnergy.suffix(20).reduce(Float.infinity) { min($0, $1.avg) }
let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: minAvgEnergy)

// Update energy for buffers with valid data
let signalEnergy = Self.calculateEnergy(of: buffer)
let newEnergy = (relativeEnergy, signalEnergy.avg, signalEnergy.max, signalEnergy.min)
self.audioEnergy.append(newEnergy)
var newEnergy = (rel: Float.zero, avg: Float.zero, max: Float.zero, min: Float.zero)
var currentAudioSampleCount = 0
var shouldLog = false
var callback: (([Float]) -> Void)?

stateLock.withLock {
audioSamplesStorage.append(contentsOf: buffer)

// Find the lowest average energy of the last 20 buffers ~2 seconds.
let minAvgEnergy = audioEnergyStorage.suffix(20).reduce(Float.infinity) { min($0, $1.avg) }
let referenceEnergy = minAvgEnergy.isFinite ? minAvgEnergy : nil
let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: referenceEnergy)

// Update energy for buffers with valid data
newEnergy = (relativeEnergy, signalEnergy.avg, signalEnergy.max, signalEnergy.min)
audioEnergyStorage.append(newEnergy)

currentAudioSampleCount = audioSamplesStorage.count
let logStride = max(1, minBufferLengthStorage * Int(relativeEnergyWindowStorage))
shouldLog = currentAudioSampleCount % logStride == 0
callback = audioBufferCallbackStorage
}
Comment thread
naykutguven marked this conversation as resolved.

// Call the callback with the new buffer
audioBufferCallback?(buffer)
// Call the callback with the new buffer outside the lock to avoid re-entrant lock attempts.
callback?(buffer)

// Print the current size of the audio buffer
if self.audioSamples.count % (minBufferLength * Int(relativeEnergyWindow)) == 0 {
Logging.debug("Current audio size: \(self.audioSamples.count) samples, most recent buffer: \(buffer.count) samples, most recent energy: \(newEnergy)")
if shouldLog {
Logging.debug("Current audio size: \(currentAudioSampleCount) samples, most recent buffer: \(buffer.count) samples, most recent energy: \(newEnergy)")
}
}

Expand Down Expand Up @@ -997,7 +1096,7 @@ public extension AudioProcessor {
throw WhisperError.audioProcessingFailed("Failed to create audio converter")
}

let bufferSize = AVAudioFrameCount(minBufferLength) // 100ms - 400ms supported
let bufferSize = AVAudioFrameCount(stateLock.withLock { minBufferLengthStorage }) // 100ms - 400ms supported
inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in
guard let self = self else { return }
var buffer = buffer
Expand All @@ -1022,23 +1121,27 @@ public extension AudioProcessor {
}

func purgeAudioSamples(keepingLast keep: Int) {
if audioSamples.count > keep {
audioSamples.removeFirst(audioSamples.count - keep)
stateLock.withLock {
if audioSamplesStorage.count > keep {
audioSamplesStorage.removeFirst(audioSamplesStorage.count - keep)
}
}
}

func startRecordingLive(inputDeviceID: DeviceID? = nil, callback: (([Float]) -> Void)? = nil) throws {
audioSamples = []
audioEnergy = []
stateLock.withLock {
audioSamplesStorage = []
audioEnergyStorage = []
}

try? setupAudioSessionForDevice()

audioEngine = try setupEngine(inputDeviceID: inputDeviceID)

// Set the callback
audioBufferCallback = callback

lastInputDevice = inputDeviceID
let engine = try setupEngine(inputDeviceID: inputDeviceID)
stateLock.withLock {
audioEngineStorage = engine
audioBufferCallbackStorage = callback
lastInputDeviceStorage = inputDeviceID
}
}

/// Starts live audio recording and returns an async stream that yields sample buffers.
Expand All @@ -1047,9 +1150,7 @@ public extension AudioProcessor {
let (stream, continuation) = AsyncThrowingStream<[Float], Error>.makeStream(bufferingPolicy: .unbounded)

continuation.onTermination = { [weak self] _ in
guard let self = self else { return }
self.audioBufferCallback = nil
self.stopRecording()
self?.stopRecording()
}

do {
Expand All @@ -1066,28 +1167,50 @@ public extension AudioProcessor {
func resumeRecordingLive(inputDeviceID: DeviceID? = nil, callback: (([Float]) -> Void)? = nil) throws {
try? setupAudioSessionForDevice()

if inputDeviceID == lastInputDevice {
try audioEngine?.start()
let engine = stateLock.withLock { () -> AVAudioEngine? in
guard inputDeviceID == lastInputDeviceStorage else {
return nil
}
return audioEngineStorage
}

if let engine {
try engine.start()
} else {
audioEngine = try setupEngine(inputDeviceID: inputDeviceID)
let engine = try setupEngine(inputDeviceID: inputDeviceID)
stateLock.withLock {
audioEngineStorage = engine
lastInputDeviceStorage = inputDeviceID
}
}

// Set the callback only if the provided callback is not nil
if let callback = callback {
audioBufferCallback = callback
stateLock.withLock {
audioBufferCallbackStorage = callback
}
}
}

func pauseRecording() {
audioEngine?.pause()
let engine = stateLock.withLock { audioEngineStorage }
engine?.pause()
}

func stopRecording() {
guard let engine = audioEngine else { return }
let engine = stateLock.withLock { () -> AVAudioEngine? in
let engine = audioEngineStorage
audioEngineStorage = nil
audioBufferCallbackStorage = nil
return engine
}

guard let engine = engine else { return }

// Remove tap from the input node explicitly.
engine.inputNode.removeTap(onBus: 0)

// Remove the tap on any attached node
engine.attachedNodes.forEach { node in
node.removeTap(onBus: 0)
}
Expand All @@ -1096,12 +1219,11 @@ public extension AudioProcessor {
// This helps prevent lingering input connections across repeated start/stop cycles.
engine.disconnectNodeInput(engine.inputNode)

// Stop the audio engine
engine.stop()

// Reset clears the engine/node state so a subsequent start builds a fresh graph.
engine.reset()

audioEngine = nil
}

func suppressInputIfNeeded(_ buffer: inout [Float]) {
Expand Down
12 changes: 12 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,18 @@ final class UnitTests: XCTestCase {
let energyVeryLoud = AudioProcessor.calculateAverageEnergy(of: veryLoudNoise)
XCTAssertGreaterThan(energyVeryLoud, energyLoud, "Audio energy is not very loud")
}

func testProcessBufferFirstBufferProducesFiniteNormalizedRelativeEnergy() throws {
    // With no prior energy history, the reference minimum-average energy
    // is undefined; the very first buffer's relative energy must still be
    // finite and clamped to the [0, 1] range.
    let sut = AudioProcessor()
    let buffer = [Float](repeating: 0.05, count: 1600)

    sut.processBuffer(buffer)

    let energy = try XCTUnwrap(sut.audioEnergy.first, "Expected processBuffer to store a first energy entry")
    XCTAssertTrue(energy.rel.isFinite, "First relative energy should always be finite")
    XCTAssertGreaterThanOrEqual(energy.rel, 0, "Relative energy should be clamped to the lower bound")
    XCTAssertLessThanOrEqual(energy.rel, 1, "Relative energy should be clamped to the upper bound")
}

// MARK: - Protocol Conformance Tests

Expand Down