Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions Sources/TTSKit/TTSKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import os

// MARK: - Callback Typealiases

/// A closure invoked when audio samples are aligned with playback.
///
/// Use this for playback-reactive features such as metering or lip sync.
public typealias PlaybackCallback = (@Sendable ([Float]) -> Void)?

// MARK: - TTSKit

/// Generic TTS orchestrator: text chunking, concurrent generation, crossfade, and audio playback.
Expand Down Expand Up @@ -997,11 +1002,14 @@ open class TTSKit: @unchecked Sendable {
language: String? = nil,
options: GenerationOptions = GenerationOptions(),
playbackStrategy: PlaybackStrategy = .auto,
callback: SpeechCallback = nil
callback: SpeechCallback = nil,
playbackCallback: PlaybackCallback = nil
) async throws -> SpeechResult {
var playOptions = options

let audioOut = audioOutput
audioOut.playbackCallback = playbackCallback
defer { audioOut.playbackCallback = nil }
let maxTokens = playOptions.maxNewTokens

// Pre-resolve audio format from the task so the playback closure doesn't
Expand Down Expand Up @@ -1062,6 +1070,25 @@ open class TTSKit: @unchecked Sendable {
return result
}

open func play(
text: String,
voice: String?,
language: String?,
options: GenerationOptions,
playbackStrategy: PlaybackStrategy,
callback: SpeechCallback
) async throws -> SpeechResult {
try await play(
text: text,
voice: voice,
language: language,
options: options,
playbackStrategy: playbackStrategy,
callback: callback,
playbackCallback: nil
)
}

// MARK: - Qwen3-typed convenience API

/// Build a prompt cache using typed Qwen3 speaker and language enums.
Expand Down Expand Up @@ -1128,15 +1155,36 @@ open class TTSKit: @unchecked Sendable {
language: Qwen3Language = .english,
options: GenerationOptions = GenerationOptions(),
playbackStrategy: PlaybackStrategy = .auto,
callback: SpeechCallback = nil
callback: SpeechCallback = nil,
playbackCallback: PlaybackCallback = nil
) async throws -> SpeechResult {
try await play(
text: text,
voice: speaker.rawValue,
language: language.rawValue,
options: options,
playbackStrategy: playbackStrategy,
callback: callback
callback: callback,
playbackCallback: playbackCallback
)
}

open func play(
text: String,
speaker: Qwen3Speaker,
language: Qwen3Language = .english,
options: GenerationOptions = GenerationOptions(),
playbackStrategy: PlaybackStrategy = .auto,
callback: SpeechCallback = nil
) async throws -> SpeechResult {
try await play(
text: text,
speaker: speaker,
language: language,
options: options,
playbackStrategy: playbackStrategy,
callback: callback,
playbackCallback: nil
)
}
}
Expand Down
21 changes: 20 additions & 1 deletion Sources/TTSKit/Utilities/AudioOutput.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class AudioOutput: @unchecked Sendable {
private var audioEngine: AVAudioEngine?
private var playerNode: AVAudioPlayerNode?
private var engineStartDeferred: Bool = false
var playbackCallback: PlaybackCallback = nil
public private(set) var isOutputSuppressed = false

/// Pre-buffer threshold in seconds. `nil` means not yet configured - frames
/// accumulate in `pendingFrames` until `setBufferDuration` is called.
Expand Down Expand Up @@ -112,6 +114,15 @@ public class AudioOutput: @unchecked Sendable {
audioFormat = format
}

/// Suppress or restore audible playback output.
///
/// This updates the active engine immediately when playback is in progress and
/// also becomes the default for the next `startPlayback()` call.
public func setOutputSuppressed(_ isSuppressed: Bool) {
isOutputSuppressed = isSuppressed
audioEngine?.mainMixerNode.outputVolume = isSuppressed ? 0 : 1
}

/// Current playback position in seconds, based on the audio engine's render timeline.
/// Returns 0 if the player is not active, no audio has been scheduled yet, or
/// the player hasn't started rendering.
Expand Down Expand Up @@ -521,6 +532,7 @@ public class AudioOutput: @unchecked Sendable {

self.audioEngine = engine
self.playerNode = player
setOutputSuppressed(isOutputSuppressed)
}

/// Enqueue a chunk of audio samples for playback.
Expand Down Expand Up @@ -672,7 +684,14 @@ public class AudioOutput: @unchecked Sendable {
expectedPlaybackEnd = max(expectedPlaybackEnd, now) + bufferSeconds
scheduledAudioDuration += bufferSeconds

player.scheduleBuffer(buffer)
if let playbackCallback = self.playbackCallback {
let playedSamples = samples
player.scheduleBuffer(buffer, completionCallbackType: .dataPlayedBack) { _ in
playbackCallback(playedSamples)
}
} else {
player.scheduleBuffer(buffer)
}
}

/// Stop playback and tear down the audio engine.
Expand Down
22 changes: 22 additions & 0 deletions Tests/TTSKitTests/TTSKitIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,28 @@ final class TTSKitIntegrationTests: XCTestCase {
XCTAssertEqual(result.sampleRate, 24000)
}

/// Playback callback should fire with non-empty samples during real-time playback.
func testPlaybackCallbackReceivesSamples() async throws {
let tts = try await makeCachedTTS(seed: 42)
tts.audioOutput.setOutputSuppressed(true)
let callbackExpectation = expectation(description: "Playback callback fired")
callbackExpectation.assertForOverFulfill = false

_ = try await tts.play(
text: "Playback callback smoke test.",
speaker: .ryan,
language: .english,
options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 80),
playbackStrategy: .generateFirst,
playbackCallback: { samples in
guard samples.isEmpty == false else { return }
callbackExpectation.fulfill()
}
)

await fulfillment(of: [callbackExpectation], timeout: 5.0)
}

// MARK: - Performance

/// Verify timings are populated and the generation loop completed within a reasonable ceiling.
Expand Down
9 changes: 9 additions & 0 deletions Tests/TTSKitTests/TTSKitUnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ final class TTSKitUnitTests: XCTestCase {
XCTAssertNotNil(opts.concurrentWorkerCount)
}

func testOutputSuppressionToggleState() {
let output = AudioOutput()
XCTAssertFalse(output.isOutputSuppressed, "Suppression should be disabled by default")
output.setOutputSuppressed(true)
XCTAssertTrue(output.isOutputSuppressed, "Suppression should be enabled after setting true")
output.setOutputSuppressed(false)
XCTAssertFalse(output.isOutputSuppressed, "Suppression should be disabled after setting false")
}

func testDownloadPatterns() {
let config = TTSKitConfig()
let patterns = config.downloadPatterns
Expand Down
Loading