Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -680,15 +680,21 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
// MARK: Non-inference

// Update predicted token as current
let logits = languageLogitsFilter.filterLogits(decoderOutput.logits!, withTokens: currentTokens)
guard let outputLogits = decoderOutput.logits else {
throw WhisperError.decodingLogitsFailed("Logits unavailable in language detection output")
}
let logits = languageLogitsFilter.filterLogits(outputLogits, withTokens: currentTokens)

// MARK: Sampling

let samplingStartTime = Date()

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
guard let sampledToken = sampleResult.tokens.last else {
throw WhisperError.decodingFailed("Language detection sampler returned empty tokens")
}
nextToken = sampledToken
logProbs = sampleResult.logProbs

let samplingTime = Date().timeIntervalSince(samplingStartTime)
Expand Down Expand Up @@ -744,7 +750,10 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
let prefilledIndex = decoderInputs.cacheLength[0].intValue
let initialPromptIndex = decoderInputs.initialPrompt.count
var currentTokens: [Int] = decoderInputs.initialPrompt
var nextToken: Int = decoderInputs.initialPrompt.last!
guard let lastPromptToken = decoderInputs.initialPrompt.last else {
throw WhisperError.prepareDecoderInputsFailed("Initial prompt is empty")
}
var nextToken: Int = lastPromptToken
var logProbs: [Float] = Array(repeating: 0, count: currentTokens.count)

// Logits filters
Expand Down Expand Up @@ -826,7 +835,10 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
let nonInferenceStartTime = Date()

// Update predicted token as current
var logits = decoderOutput.logits!
guard let outputLogits = decoderOutput.logits else {
throw WhisperError.decodingLogitsFailed("Logits unavailable in decoder output")
}
var logits = outputLogits
for filter in logitsFilters {
logits = filter.filterLogits(logits, withTokens: currentTokens)
}
Expand All @@ -840,8 +852,14 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
let nextTokenLogProb = sampleResult.logProbs.last!
guard let sampledToken = sampleResult.tokens.last else {
throw WhisperError.decodingFailed("Token sampler returned empty tokens")
}
guard let sampledLogProb = sampleResult.logProbs.last else {
throw WhisperError.decodingFailed("Token sampler returned empty logProbs")
}
nextToken = sampledToken
let nextTokenLogProb = sampledLogProb

Logging.debug("Predicted next tokenIndex: \(tokenIndex + 1), token: \(nextToken), text: \(tokenizer.decode(tokens: [nextToken]))")

Expand Down