diff --git a/flutter/sherpa_onnx/lib/src/offline_recognizer.dart b/flutter/sherpa_onnx/lib/src/offline_recognizer.dart index ab7506f667..8d127cbca4 100644 --- a/flutter/sherpa_onnx/lib/src/offline_recognizer.dart +++ b/flutter/sherpa_onnx/lib/src/offline_recognizer.dart @@ -736,6 +736,7 @@ class OfflineRecognizerResult { OfflineRecognizerResult({ required this.text, required this.tokens, + required this.tokenLogProbs, required this.timestamps, required this.lang, required this.emotion, @@ -746,6 +747,10 @@ class OfflineRecognizerResult { return OfflineRecognizerResult( text: json['text'] as String? ?? '', tokens: (json['tokens'] as List?)?.map((e) => e as String).toList() ?? [], + tokenLogProbs: (json['token_log_probs'] as List?) + ?.map((e) => (e as num).toDouble()) + .toList() ?? + [], timestamps: (json['timestamps'] as List?) ?.map((e) => (e as num).toDouble()) @@ -759,12 +764,13 @@ class OfflineRecognizerResult { @override String toString() { - return 'OfflineRecognizerResult(text: $text, tokens: $tokens, timestamps: $timestamps, lang: $lang, emotion: $emotion, event: $event)'; + return 'OfflineRecognizerResult(text: $text, tokens: $tokens, tokenLogProbs: $tokenLogProbs, timestamps: $timestamps, lang: $lang, emotion: $emotion, event: $event)'; } Map toJson() => { 'text': text, 'tokens': tokens, + 'token_log_probs': tokenLogProbs, 'timestamps': timestamps, 'lang': lang, 'emotion': emotion, @@ -773,6 +779,7 @@ class OfflineRecognizerResult { final String text; final List tokens; + final List tokenLogProbs; final List timestamps; final String lang; final String emotion; @@ -793,31 +800,32 @@ class OfflineRecognizer { /// method of the returned instance to avoid memory leak. factory OfflineRecognizer(OfflineRecognizerConfig config) { - final c = convertConfig(config); - if (SherpaOnnxBindings.createOfflineRecognizer == null) { throw Exception("Please initialize sherpa-onnx first"); } - final ptr = SherpaOnnxBindings.createOfflineRecognizer?.call(c) ?? 
nullptr; + final c = convertConfig(config); - if (ptr == nullptr) { - throw Exception( - "Failed to create offline recognizer. Please check your config", - ); + try { + final ptr = SherpaOnnxBindings.createOfflineRecognizer!.call(c); + if (ptr == nullptr) { + throw Exception( + "Failed to create offline recognizer. Please check your config", + ); + } + return OfflineRecognizer._(ptr: ptr, config: config); + } finally { + freeConfig(c); } - - freeConfig(c); - - return OfflineRecognizer._(ptr: ptr, config: config); } void setConfig(OfflineRecognizerConfig config) { final c = convertConfig(config); - - SherpaOnnxBindings.offlineRecognizerSetConfig?.call(ptr, c); - - freeConfig(c); + try { + SherpaOnnxBindings.offlineRecognizerSetConfig?.call(ptr, c); + } finally { + freeConfig(c); + } // we don't update this.config } @@ -1033,6 +1041,7 @@ class OfflineRecognizer { return OfflineRecognizerResult( text: '', tokens: [], + tokenLogProbs: [], timestamps: [], lang: '', emotion: '', @@ -1044,13 +1053,8 @@ class OfflineRecognizer { SherpaOnnxBindings.destroyOfflineStreamResultJson?.call(json); - return OfflineRecognizerResult( - text: parsedJson['text'], - tokens: List.from(parsedJson['tokens']), - timestamps: List.from(parsedJson['timestamps']), - lang: parsedJson['lang'], - emotion: parsedJson['emotion'], - event: parsedJson['event'], + return OfflineRecognizerResult.fromJson( + parsedJson as Map, ); } diff --git a/flutter/sherpa_onnx/lib/src/offline_stream.dart b/flutter/sherpa_onnx/lib/src/offline_stream.dart index 0b6f9c8667..66f95772a6 100644 --- a/flutter/sherpa_onnx/lib/src/offline_stream.dart +++ b/flutter/sherpa_onnx/lib/src/offline_stream.dart @@ -33,5 +33,49 @@ class OfflineStream { calloc.free(p); } + Map>? 
getVocabLogProbs() { + final getFunc = SherpaOnnxBindings.getOfflineStreamVocabLogProbs; + final destroyFunc = SherpaOnnxBindings.destroyVocabLogProbs; + + if (getFunc == null || destroyFunc == null) { + return null; + } + + final vocabPtr = getFunc(ptr); + if (vocabPtr == nullptr) { + return null; + } + + final vocabLogProbs = vocabPtr.ref; + final numTokens = vocabLogProbs.numTokens; + final vocabSize = vocabLogProbs.vocabSize; + + // Defensive validation for native values + if (numTokens < 0 || + vocabSize < 0 || + numTokens > 10000 || + vocabSize > 100000) { + destroyFunc(vocabPtr); + return null; + } + + final Map> result = {}; + + for (int tokenIdx = 0; tokenIdx < numTokens; tokenIdx++) { + final List tokenProbs = []; + + for (int vocabIdx = 0; vocabIdx < vocabSize; vocabIdx++) { + final index = tokenIdx * vocabSize + vocabIdx; + final logProb = vocabLogProbs.logProbs[index]; + tokenProbs.add(logProb); + } + + result['token_$tokenIdx'] = tokenProbs; + } + + destroyFunc(vocabPtr); + return result; + } + Pointer ptr; } diff --git a/flutter/sherpa_onnx/lib/src/online_recognizer.dart b/flutter/sherpa_onnx/lib/src/online_recognizer.dart index 14fe4a45ec..b75ba386f2 100644 --- a/flutter/sherpa_onnx/lib/src/online_recognizer.dart +++ b/flutter/sherpa_onnx/lib/src/online_recognizer.dart @@ -327,32 +327,37 @@ class OnlineRecognizerConfig { class OnlineRecognizerResult { OnlineRecognizerResult( - {required this.text, required this.tokens, required this.timestamps}); + {required this.text, required this.tokens, required this.timestamps, this.ysProbs = const []}); factory OnlineRecognizerResult.fromJson(Map json) { return OnlineRecognizerResult( - text: json['text'] as String, - tokens: List.from(json['tokens'] as List), - timestamps: (json['timestamps'] as List) - .map((e) => (e as num).toDouble()) - .toList(), + text: json['text'] as String? ?? '', + tokens: (json['tokens'] as List?)?.cast().toList() ?? const [], + timestamps: (json['timestamps'] as List?) 
+ ?.map((e) => (e as num).toDouble()) + .toList() ?? const [], + ysProbs: (json['ys_probs'] as List?) + ?.map((e) => (e as num).toDouble()) + .toList() ?? const [], ); } @override String toString() { - return 'OnlineRecognizerResult(text: $text, tokens: $tokens, timestamps: $timestamps)'; + return 'OnlineRecognizerResult(text: $text, tokens: $tokens, timestamps: $timestamps, ysProbs: $ysProbs)'; } Map toJson() => { 'text': text, 'tokens': tokens, 'timestamps': timestamps, + 'ys_probs': ysProbs, }; final String text; final List tokens; final List timestamps; + final List ysProbs; } class OnlineRecognizer { @@ -496,10 +501,9 @@ class OnlineRecognizer { SherpaOnnxBindings.destroyOnlineStreamResultJson?.call(json); - return OnlineRecognizerResult( - text: parsedJson['text'], - tokens: List.from(parsedJson['tokens']), - timestamps: List.from(parsedJson['timestamps'])); + return OnlineRecognizerResult.fromJson( + parsedJson as Map, + ); } void reset(OnlineStream stream) { diff --git a/flutter/sherpa_onnx/lib/src/online_stream.dart b/flutter/sherpa_onnx/lib/src/online_stream.dart index e1f61e15c1..6fbd7a5d64 100644 --- a/flutter/sherpa_onnx/lib/src/online_stream.dart +++ b/flutter/sherpa_onnx/lib/src/online_stream.dart @@ -37,5 +37,49 @@ class OnlineStream { SherpaOnnxBindings.onlineStreamInputFinished?.call(ptr); } + Map>? 
getVocabLogProbs() { + final getFunc = SherpaOnnxBindings.getOnlineStreamVocabLogProbs; + final destroyFunc = SherpaOnnxBindings.destroyVocabLogProbs; + + if (getFunc == null || destroyFunc == null) { + return null; + } + + final vocabPtr = getFunc(this.ptr); + if (vocabPtr == nullptr) { + return null; + } + + final vocabLogProbs = vocabPtr.ref; + final numTokens = vocabLogProbs.numTokens; + final vocabSize = vocabLogProbs.vocabSize; + + // Defensive validation for native values + if (numTokens < 0 || + vocabSize < 0 || + numTokens > 10000 || + vocabSize > 100000) { + destroyFunc(vocabPtr); + return null; + } + + final Map> result = {}; + + for (int tokenIdx = 0; tokenIdx < numTokens; tokenIdx++) { + final List tokenProbs = []; + + for (int vocabIdx = 0; vocabIdx < vocabSize; vocabIdx++) { + final index = tokenIdx * vocabSize + vocabIdx; + final logProb = vocabLogProbs.logProbs[index]; + tokenProbs.add(logProb); + } + + result['token_$tokenIdx'] = tokenProbs; + } + + destroyFunc(vocabPtr); + return result; + } + Pointer ptr; } diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index c01a16eb01..89c0445272 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -776,6 +776,17 @@ final class SherpaOnnxSpokenLanguageIdentificationResult extends Struct { final class SherpaOnnxSpokenLanguageIdentification extends Opaque {} +final class SherpaOnnxVocabLogProbs extends Struct { + // Flattened 2D array + external Pointer logProbs; + + @Int32() + external int numTokens; + + @Int32() + external int vocabSize; +} + final class SherpaOnnxOfflineSpeechDenoiser extends Opaque {} typedef SherpaOnnxCreateOfflineSpeechDenoiserNative = @@ -1724,6 +1735,20 @@ typedef SherpaOnnxGetGitSha1 = SherpaOnnxGetGitSha1Native; typedef SherpaOnnxGetGitDateNative = Pointer Function(); typedef SherpaOnnxGetGitDate = SherpaOnnxGetGitDateNative; +typedef 
SherpaOnnxOnlineStreamGetVocabLogProbsNative + = Pointer Function( + Pointer stream); + +typedef SherpaOnnxOfflineStreamGetVocabLogProbsNative + = Pointer Function( + Pointer stream); + +typedef SherpaOnnxDestroyVocabLogProbsNative = Void Function( + Pointer logProbs); + +typedef SherpaOnnxDestroyVocabLogProbsDart = void Function( + Pointer); + class SherpaOnnxBindings { static SherpaOnnxCreateOfflineSpeechDenoiser? sherpaOnnxCreateOfflineSpeechDenoiser; @@ -1951,6 +1976,12 @@ class SherpaOnnxBindings { static SherpaOnnxGetGitSha1? getGitSha1; static SherpaOnnxGetGitDate? getGitDate; + static SherpaOnnxOnlineStreamGetVocabLogProbsNative? + getOnlineStreamVocabLogProbs; + static SherpaOnnxOfflineStreamGetVocabLogProbsNative? + getOfflineStreamVocabLogProbs; + static SherpaOnnxDestroyVocabLogProbsDart? destroyVocabLogProbs; + static void init(DynamicLibrary dynamicLibrary) { sherpaOnnxCreateOfflineSpeechDenoiser ??= dynamicLibrary .lookup>( @@ -2674,5 +2705,23 @@ class SherpaOnnxBindings { 'SherpaOnnxGetGitDate', ) .asFunction(); + + getOnlineStreamVocabLogProbs ??= dynamicLibrary + .lookup>( + 'SherpaOnnxOnlineStreamGetVocabLogProbs', + ) + .asFunction(); + + getOfflineStreamVocabLogProbs ??= dynamicLibrary + .lookup>( + 'SherpaOnnxOfflineStreamGetVocabLogProbs', + ) + .asFunction(); + + destroyVocabLogProbs ??= dynamicLibrary + .lookup>( + 'SherpaOnnxDestroyVocabLogProbs', + ) + .asFunction(); } } diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 83bc1e0172..38063fede4 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -857,6 +857,65 @@ void SherpaOnnxDestroyOfflineStreamResultJson(const char *s) { delete[] s; } +const struct SherpaOnnxVocabLogProbs *SherpaOnnxOnlineStreamGetVocabLogProbs( + const SherpaOnnxOnlineStream *stream) { + const auto &result = stream->impl->GetResult(); + + if (result.vocab_log_probs.empty()) { + return nullptr; + } + + auto vocab_probs = new SherpaOnnxVocabLogProbs; + 
vocab_probs->num_tokens = result.vocab_log_probs.size(); + vocab_probs->vocab_size = result.vocab_log_probs[0].size(); + + // Flatten the 2D vector into a 1D array + float *flat_probs = + new float[vocab_probs->num_tokens * vocab_probs->vocab_size]; + for (int32_t i = 0; i < vocab_probs->num_tokens; ++i) { + std::copy(result.vocab_log_probs[i].begin(), + result.vocab_log_probs[i].end(), + flat_probs + i * vocab_probs->vocab_size); + } + vocab_probs->log_probs = flat_probs; + + return vocab_probs; +} + +const struct SherpaOnnxVocabLogProbs *SherpaOnnxOfflineStreamGetVocabLogProbs( + const SherpaOnnxOfflineStream *stream) { + const sherpa_onnx::OfflineRecognitionResult &result = + stream->impl->GetResult(); + + if (result.vocab_log_probs.empty()) { + return nullptr; + } + + auto vocab_probs = new SherpaOnnxVocabLogProbs; + vocab_probs->num_tokens = result.vocab_log_probs.size(); + vocab_probs->vocab_size = result.vocab_log_probs[0].size(); + + // Flatten the 2D vector into a 1D array + float *flat_probs = + new float[vocab_probs->num_tokens * vocab_probs->vocab_size]; + for (int32_t i = 0; i < vocab_probs->num_tokens; ++i) { + std::copy(result.vocab_log_probs[i].begin(), + result.vocab_log_probs[i].end(), + flat_probs + i * vocab_probs->vocab_size); + } + vocab_probs->log_probs = flat_probs; + + return vocab_probs; +} + +void SherpaOnnxDestroyVocabLogProbs( + const struct SherpaOnnxVocabLogProbs *log_probs) { + if (log_probs) { + delete[] log_probs->log_probs; + delete log_probs; + } +} + // ============================================================ // For Keyword Spot // ============================================================ diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 61532d2e87..798c1c6e3d 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -729,6 +729,21 @@ typedef struct SherpaOnnxOfflineRecognizerResult { int32_t segment_count; } SherpaOnnxOfflineRecognizerResult; +SHERPA_ONNX_API typedef struct 
SherpaOnnxVocabLogProbs { + const float *log_probs; // Flattened 2D array [num_tokens][vocab_size] + int32_t num_tokens; + int32_t vocab_size; +} SherpaOnnxVocabLogProbs; + +SHERPA_ONNX_API const SherpaOnnxVocabLogProbs * +SherpaOnnxOnlineStreamGetVocabLogProbs(const SherpaOnnxOnlineStream *stream); + +SHERPA_ONNX_API const SherpaOnnxVocabLogProbs * +SherpaOnnxOfflineStreamGetVocabLogProbs(const SherpaOnnxOfflineStream *stream); + +SHERPA_ONNX_API void SherpaOnnxDestroyVocabLogProbs( + const SherpaOnnxVocabLogProbs *log_probs); + /// Get the result of the offline stream. /// /// We assume you have called SherpaOnnxDecodeOfflineStream() or diff --git a/sherpa-onnx/csrc/offline-ctc-decoder.h b/sherpa-onnx/csrc/offline-ctc-decoder.h index c9d1b36ffa..381aba9e51 100644 --- a/sherpa-onnx/csrc/offline-ctc-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-decoder.h @@ -26,6 +26,11 @@ struct OfflineCtcDecoderResult { /// /// tokens.size() == timestamps.size() std::vector timestamps; + + /// Token-level log probabilities (confidence scores). + /// May be empty if not provided by the decoder. 
+ /// If populated, token_log_probs.size() == tokens.size() + std::vector token_log_probs; }; class OfflineCtcDecoder { diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc index 59d16f5d32..cf2abf76bf 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -32,16 +32,18 @@ std::vector OfflineCtcGreedySearchDecoder::Decode( int64_t prev_id = -1; for (int32_t t = 0; t != static_cast(p_log_probs_length[b]); ++t) { - auto y = static_cast(std::distance( + auto max_it = std::max_element( static_cast(p_log_probs), - std::max_element( - static_cast(p_log_probs), - static_cast(p_log_probs) + vocab_size))); + static_cast(p_log_probs) + vocab_size); + auto y = static_cast(std::distance( + static_cast(p_log_probs), max_it)); + float log_prob = *max_it; p_log_probs += vocab_size; if (y != blank_id_ && y != prev_id) { r.tokens.push_back(y); r.timestamps.push_back(t); + r.token_log_probs.push_back(log_prob); } prev_id = y; } // for (int32_t t = 0; ...) 
diff --git a/sherpa-onnx/csrc/offline-moonshine-decoder.h b/sherpa-onnx/csrc/offline-moonshine-decoder.h index 4d0b9ac93d..1b4b3f15cb 100644 --- a/sherpa-onnx/csrc/offline-moonshine-decoder.h +++ b/sherpa-onnx/csrc/offline-moonshine-decoder.h @@ -14,6 +14,9 @@ namespace sherpa_onnx { struct OfflineMoonshineDecoderResult { /// The decoded token IDs std::vector tokens; + /// Token-level log probabilities (confidence scores) + std::vector token_log_probs; + std::vector> vocab_log_probs; }; class OfflineMoonshineDecoder { diff --git a/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc index dba090238f..4c3d8c15e9 100644 --- a/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/offline-moonshine-greedy-search-decoder.h" #include +#include #include #include @@ -38,6 +39,8 @@ OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) { int32_t seq_len = 1; std::vector tokens; + std::vector token_log_probs; + std::vector> vocab_log_probs; std::array token_shape = {1, 1}; int64_t seq_len_shape = 1; @@ -59,12 +62,36 @@ OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) { for (int32_t i = 0; i != max_len; ++i) { const float *p = logits.GetTensorData(); - int32_t max_token_id = static_cast( - std::distance(p, std::max_element(p, p + vocab_size))); + // Compute log-softmax once for both max selection and storage + float max_logit = *std::max_element(p, p + vocab_size); + + double sum_exp = 0.0; + for (int32_t j = 0; j < vocab_size; ++j) { + sum_exp += std::exp(p[j] - max_logit); + } + float log_sum = max_logit + static_cast(std::log(sum_exp)); + + // Compute log-softmax for all tokens and find max in single pass + std::vector full_vocab_probs(vocab_size); + int32_t max_token_id = 0; + float max_log_prob = p[0] - log_sum; + full_vocab_probs[0] = max_log_prob; + + for 
(int32_t j = 1; j < vocab_size; ++j) { + float log_prob = p[j] - log_sum; + full_vocab_probs[j] = log_prob; + if (log_prob > max_log_prob) { + max_log_prob = log_prob; + max_token_id = j; + } + } + if (max_token_id == eos) { break; } tokens.push_back(max_token_id); + token_log_probs.push_back(max_log_prob); + vocab_log_probs.push_back(std::move(full_vocab_probs)); seq_len += 1; @@ -87,6 +114,8 @@ OfflineMoonshineGreedySearchDecoder::Decode(Ort::Value encoder_out) { OfflineMoonshineDecoderResult ans; ans.tokens = std::move(tokens); + ans.token_log_probs = std::move(token_log_probs); + ans.vocab_log_probs = std::move(vocab_log_probs); return {ans}; } diff --git a/sherpa-onnx/csrc/offline-recognizer-canary-impl.h b/sherpa-onnx/csrc/offline-recognizer-canary-impl.h index 8744899c6a..3793dd5ac5 100644 --- a/sherpa-onnx/csrc/offline-recognizer-canary-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-canary-impl.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CANARY_IMPL_H_ #include +#include #include #include #include @@ -56,6 +57,13 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { void DecodeStream(OfflineStream *s) const { auto meta = model_->GetModelMetadata(); auto enc_out = RunEncoder(s); + + if (enc_out.empty()) { + OfflineRecognitionResult empty_result; + s->SetResult(empty_result); + return; + } + Ort::Value enc_states = std::move(enc_out[0]); Ort::Value enc_mask = std::move(enc_out[2]); // enc_out[1] is discarded @@ -69,7 +77,10 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { View(&enc_states), View(&enc_mask)); } - int32_t max_token_id = GetMaxTokenId(&logits); + std::vector full_vocab_probs; + auto [max_token_id, confidence] = + GetMaxTokenIdWithConfidence(&logits, &full_vocab_probs); + int32_t eos = symbol_table_["<|endoftext|>"]; int32_t num_feature_frames = @@ -77,27 +88,50 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { meta.subsampling_factor; std::vector tokens = 
{max_token_id}; + std::vector token_log_probs = {confidence}; + std::vector> vocab_log_probs; + vocab_log_probs.push_back(std::move(full_vocab_probs)); // Assume 30 tokens per second. It is to avoid the following for loop // running indefinitely. int32_t num_tokens = static_cast(num_feature_frames / 100.0 * 30) + 1; - for (int32_t i = 1; i <= num_tokens; ++i) { - if (tokens.back() == eos) { - break; - } + // Reserve space to reduce reallocations + tokens.reserve(num_tokens + 1); + token_log_probs.reserve(num_tokens + 1); + vocab_log_probs.reserve(num_tokens + 1); - std::tie(logits, decoder_states) = - RunDecoder(tokens.back(), i, std::move(decoder_states), - View(&enc_states), View(&enc_mask)); - tokens.push_back(GetMaxTokenId(&logits)); + if (max_token_id != eos) { + for (int32_t i = 1; i <= num_tokens; ++i) { + if (tokens.back() == eos) { + break; + } + + std::tie(logits, decoder_states) = + RunDecoder(tokens.back(), i, std::move(decoder_states), + View(&enc_states), View(&enc_mask)); + + std::vector next_full_vocab_probs; + auto [next_token_id, next_confidence] = + GetMaxTokenIdWithConfidence(&logits, &next_full_vocab_probs); + + tokens.push_back(next_token_id); + token_log_probs.push_back(next_confidence); + vocab_log_probs.push_back(std::move(next_full_vocab_probs)); + } } - // remove the last eos token - tokens.pop_back(); + // remove the last eos token and its confidence + if (!tokens.empty() && tokens.back() == eos) { + tokens.pop_back(); + token_log_probs.pop_back(); + vocab_log_probs.pop_back(); + } - auto r = Convert(tokens); + // Convert with vocab_log_probs - filtering happens in one place for alignment + // Move vocab_log_probs since they won't be used after this + auto r = Convert(tokens, token_log_probs, std::move(vocab_log_probs)); r.text = ApplyInverseTextNormalization(std::move(r.text)); r.text = ApplyHomophoneReplacer(std::move(r.text)); @@ -116,19 +150,36 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { } private: - 
OfflineRecognitionResult Convert(const std::vector &tokens) const { + OfflineRecognitionResult Convert( + const std::vector &tokens, + const std::vector &token_log_probs, + std::vector> vocab_log_probs = {}) const { OfflineRecognitionResult r; r.tokens.reserve(tokens.size()); + if (!vocab_log_probs.empty()) { + r.vocab_log_probs.reserve(tokens.size()); + } std::string text; - for (auto i : tokens) { - if (!symbol_table_.Contains(i)) { + for (size_t idx = 0; idx < tokens.size(); ++idx) { + int32_t token_id = tokens[idx]; + + if (!symbol_table_.Contains(token_id)) { continue; } - const auto &s = symbol_table_[i]; + const auto &s = symbol_table_[token_id]; text += s; r.tokens.push_back(s); + + if (idx < token_log_probs.size()) { + r.token_log_probs.push_back(token_log_probs[idx]); + } + + // Filter vocab_log_probs in the same loop to maintain alignment + if (!vocab_log_probs.empty() && idx < vocab_log_probs.size()) { + r.vocab_log_probs.push_back(std::move(vocab_log_probs[idx])); + } } r.text = std::move(text); @@ -136,15 +187,52 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { return r; } - int32_t GetMaxTokenId(Ort::Value *logits) const { + std::pair GetMaxTokenIdWithConfidence( + Ort::Value *logits, + std::vector *full_distribution = nullptr) const { // logits is of shape (1, 1, vocab_size) auto meta = model_->GetModelMetadata(); const float *p_logits = logits->GetTensorData(); - int32_t max_token_id = static_cast(std::distance( - p_logits, std::max_element(p_logits, p_logits + meta.vocab_size))); + // Find max for numerical stability + float max_logit = *std::max_element(p_logits, p_logits + meta.vocab_size); + + // Compute log_softmax using double for intermediate precision + double sum_exp = 0.0; + for (int32_t i = 0; i < meta.vocab_size; ++i) { + sum_exp += std::exp(p_logits[i] - max_logit); + } + float log_sum = max_logit + static_cast(std::log(sum_exp)); + + // Find the max token and its log probability + // If full_distribution is 
requested, compute both in a single pass + int32_t max_token_id = 0; + float max_log_prob = p_logits[0] - log_sum; + + if (full_distribution != nullptr) { + full_distribution->resize(meta.vocab_size); + full_distribution->at(0) = max_log_prob; + + for (int32_t i = 1; i < meta.vocab_size; ++i) { + float log_prob = p_logits[i] - log_sum; + full_distribution->at(i) = log_prob; + if (log_prob > max_log_prob) { + max_log_prob = log_prob; + max_token_id = i; + } + } + } else { + // Only find max if full distribution not needed + for (int32_t i = 1; i < meta.vocab_size; ++i) { + float log_prob = p_logits[i] - log_sum; + if (log_prob > max_log_prob) { + max_log_prob = log_prob; + max_token_id = i; + } + } + } - return max_token_id; + return {max_token_id, max_log_prob}; } std::vector RunEncoder(OfflineStream *s) const { @@ -154,6 +242,16 @@ class OfflineRecognizerCanaryImpl : public OfflineRecognizerImpl { int32_t feat_dim = config_.feat_config.feature_dim; std::vector f = s->GetFrames(); + if (f.empty()) { + return {}; + } + + // Validate feat_dim to prevent division by zero + if (feat_dim <= 0) { + SHERPA_ONNX_LOGE("Invalid feature dimension: %d", feat_dim); + return {}; + } + int32_t num_frames = f.size() / feat_dim; std::array shape = {1, num_frames, feat_dim}; diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 8647e658d8..fc2cbcc50a 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -31,6 +31,8 @@ OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.timestamps.size()); + r.token_log_probs.reserve(src.token_log_probs.size()); + std::string text; for (int32_t i = 0; i != src.tokens.size(); ++i) { @@ -57,6 +59,11 @@ OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, } r.tokens.push_back(std::move(sym)); + + // Add confidence score if available + if 
(i < src.token_log_probs.size()) { + r.token_log_probs.push_back(src.token_log_probs[i]); + } } if (sym_table.IsByteBpe()) { diff --git a/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h b/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h index 7ae7210802..7a9d40ff9e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-moonshine-impl.h @@ -28,9 +28,13 @@ OfflineRecognitionResult Convert(const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); + if (!src.vocab_log_probs.empty()) { + r.vocab_log_probs.reserve(src.tokens.size()); + } std::string text; - for (auto i : src.tokens) { + for (size_t idx = 0; idx < src.tokens.size(); ++idx) { + auto i = src.tokens[idx]; if (!sym_table.Contains(i)) { continue; } @@ -38,6 +42,14 @@ OfflineRecognitionResult Convert(const OfflineMoonshineDecoderResult &src, const auto &s = sym_table[i]; text += s; r.tokens.push_back(s); + + if (idx < src.token_log_probs.size()) { + r.token_log_probs.push_back(src.token_log_probs[idx]); + } + + if (!src.vocab_log_probs.empty() && idx < src.vocab_log_probs.size()) { + r.vocab_log_probs.push_back(src.vocab_log_probs[idx]); + } } r.text = text; diff --git a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h index 056c79f00f..ae2983a33e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h @@ -53,6 +53,14 @@ OfflineRecognitionResult ConvertSenseVoiceResult( r.words = std::move(src.words); + // Propagate token-level log probabilities (skipping the first 4 control + // tokens for non-NanO models, same as tokens/timestamps above) + if (!src.token_log_probs.empty()) { + for (int32_t i = start; i < static_cast(src.token_log_probs.size()); ++i) { + r.token_log_probs.push_back(src.token_log_probs[i]); + } + } + if (!is_funasr_nano) 
{ // parse lang, emotion and event from tokens. if (src.tokens.size() >= 3) { diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h index a12b8d7c39..7996118d83 100644 --- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -152,6 +152,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { const SymbolTable &sym_table) const { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); + r.token_log_probs.reserve(src.token_log_probs.size()); + r.vocab_log_probs.reserve(src.vocab_log_probs.size()); std::string text; @@ -161,7 +163,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { config_.model_config.whisper.enable_segment_timestamps; // Build text, skipping timestamp tokens if in segment timestamp mode - for (auto i : src.tokens) { + for (size_t idx = 0; idx < src.tokens.size(); ++idx) { + auto i = src.tokens[idx]; // Skip timestamp tokens (they are >= timestamp_begin) if (enable_segment_timestamps && i >= timestamp_begin) { continue; @@ -177,6 +180,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { text += s; r.tokens.push_back(s); + if (idx < src.token_log_probs.size()) { + r.token_log_probs.push_back(src.token_log_probs[idx]); + } + if (idx < src.vocab_log_probs.size()) { + r.vocab_log_probs.push_back(src.vocab_log_probs[idx]); + } } r.text = text; diff --git a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index f33bd2f589..606c1397fb 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -419,6 +419,21 @@ std::string OfflineRecognitionResult::AsJsonString() const { } os << "], "; + // Serialize token_log_probs (custom field, may be same as ys_log_probs) + if (!token_log_probs.empty()) { + os << "\"" + << "token_log_probs" + << "\"" + << ": "; + os << "["; + sep = ""; + for (auto prob : token_log_probs) { + os 
<< sep << std::fixed << std::setprecision(4) << prob; + sep = ", "; + } + os << "], "; + } + sep = ""; os << "\"" diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 5ef7e579b4..58b94fdff2 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -56,6 +56,10 @@ struct OfflineRecognitionResult { std::vector segment_texts; // text of each segment std::string AsJsonString() const; + + /// Token-level probabilities + std::vector token_log_probs; + std::vector> vocab_log_probs; }; struct WhisperTag { diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index dfc8ae7756..e7dc31ac22 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" #include +#include #include #include @@ -153,8 +154,6 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, } } - std::vector predicted_tokens; - // Storage for accumulated attention weights std::vector> all_attention_weights; int32_t attention_n_heads = 0; @@ -194,12 +193,46 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, int32_t num_possible_tokens = num_feature_frames / 100.0 * 6; num_possible_tokens = std::min(num_possible_tokens, n_text_ctx / 2); + std::vector predicted_tokens; + // Log probabilities. 
+ std::vector predicted_log_probs; + std::vector> predicted_vocab_log_probs; + + // Reserve capacity to avoid reallocations + predicted_tokens.reserve(num_possible_tokens); + predicted_log_probs.reserve(num_possible_tokens); + predicted_vocab_log_probs.reserve(num_possible_tokens); + for (int32_t i = 0; i < num_possible_tokens; ++i) { if (max_token_id == eot) { break; } + // Compute log-softmax for the full vocabulary from raw logits + const float *raw_logits = std::get<0>(decoder_out).GetTensorData(); + // For initial iteration, logits have shape (1, n_tokens, vocab_size), + // take last token; for subsequent iterations, shape is (1, 1, vocab_size) + auto cur_logits_shape = std::get<0>(decoder_out).GetTensorTypeAndShapeInfo().GetShape(); + const float *current_logits = raw_logits + (cur_logits_shape[1] - 1) * vocab_size; + + std::vector full_vocab_probs(vocab_size); + auto max_iter = std::max_element(current_logits, current_logits + vocab_size); + float max_logit = *max_iter; + double sum_exp = 0.0; + for (int32_t j = 0; j < vocab_size; ++j) { + sum_exp += std::exp(current_logits[j] - max_logit); + } + float log_sum = max_logit + std::log(sum_exp); + for (int32_t j = 0; j < vocab_size; ++j) { + full_vocab_probs[j] = current_logits[j] - log_sum; + } + + // Extract log probability for the selected token + float log_prob = full_vocab_probs[max_token_id]; + predicted_tokens.push_back(max_token_id); + predicted_log_probs.push_back(log_prob); + predicted_vocab_log_probs.push_back(std::move(full_vocab_probs)); all_tokens.push_back(max_token_id); // Track if this is a timestamp token (for filtering in DTW) @@ -274,6 +307,8 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, } ans[0].tokens = std::move(predicted_tokens); + ans[0].token_log_probs = std::move(predicted_log_probs); + ans[0].vocab_log_probs = std::move(predicted_vocab_log_probs); // Parse timestamp tokens into segments if using segment timestamp mode if (enable_segment_timestamps) { diff --git 
a/sherpa-onnx/csrc/offline-whisper-model-config.h b/sherpa-onnx/csrc/offline-whisper-model-config.h index c286021ffe..7070aa50a8 100644 --- a/sherpa-onnx/csrc/offline-whisper-model-config.h +++ b/sherpa-onnx/csrc/offline-whisper-model-config.h @@ -81,7 +81,11 @@ struct OfflineWhisperSegment { struct OfflineWhisperDecoderResult { /// The decoded token IDs std::vector tokens; + /// The log probabilities for each token + std::vector token_log_probs; std::string lang; + /// Full vocabulary log probabilities at each token position + std::vector> vocab_log_probs; /// Cross-attention weights for token-level timestamps (if enabled) /// Shape: (n_heads, n_tokens, n_audio_frames), flattened to 1D diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 64b28aba5d..72968f92a7 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -76,6 +76,7 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.ys_probs = std::move(src.ys_probs); r.lm_probs = std::move(src.lm_probs); r.context_scores = std::move(src.context_scores); + r.vocab_log_probs = src.vocab_log_probs; r.segment = segment; r.start_time = frames_since_start * frame_shift_ms / 1000.; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 99b648c02e..f30fbd1477 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -51,7 +51,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), config_.blank_penalty); + model_.get(), config_.blank_penalty, config_.temperature_scale); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -75,7 +75,7 @@ 
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { } if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( - model_.get(), config_.blank_penalty); + model_.get(), config_.blank_penalty, config_.temperature_scale); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index d98c114fbf..6f1b7dac7f 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -42,6 +42,7 @@ struct OnlineRecognizerResult { // /// log-domain scores from "hot-phrase" contextual boosting std::vector<float> context_scores; + std::vector<std::vector<float>> vocab_log_probs; std::vector words; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.cc b/sherpa-onnx/csrc/online-transducer-decoder.cc index 682b9bc7eb..7bed58bd6f 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-decoder.cc @@ -40,6 +40,7 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( ys_probs = other.ys_probs; lm_probs = other.lm_probs; context_scores = other.context_scores; + vocab_log_probs = other.vocab_log_probs; return *this; } @@ -67,6 +68,7 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( ys_probs = std::move(other.ys_probs); lm_probs = std::move(other.lm_probs); context_scores = std::move(other.context_scores); + vocab_log_probs = std::move(other.vocab_log_probs); return *this; } diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index e507a0fc4f..098020ba06 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -30,6 +30,10 @@ struct OnlineTransducerDecoderResult { std::vector<float> lm_probs; std::vector<float> context_scores; + /// Shape: (num_emitted_tokens, vocab_size) + /// Empty if confidence calculation is disabled + std::vector<std::vector<float>>
vocab_log_probs; + // Cache decoder_out for endpointing Ort::Value decoder_out; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index f609825be3..4a3e083ce0 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -138,15 +138,12 @@ void OnlineTransducerGreedySearchDecoder::Decode( r.tokens.push_back(y); r.timestamps.push_back(t + r.frame_offset); r.num_trailing_blanks = 0; - } else { - ++r.num_trailing_blanks; - } - // export the per-token log scores - if (y != 0 && y != unk_id_) { + // export the per-token log scores // apply temperature-scaling + float temp = temperature_scale_ > 0.0f ? temperature_scale_ : 1.0f; for (int32_t n = 0; n < vocab_size; ++n) { - p_logit[n] /= temperature_scale_; + p_logit[n] /= temp; } LogSoftmax(p_logit, vocab_size); // renormalize probabilities, // save time by doing it only for @@ -155,6 +152,12 @@ void OnlineTransducerGreedySearchDecoder::Decode( // now it contains normalized // probability r.ys_probs.push_back(p_logprob[y]); + + // Store full vocabulary distribution + std::vector<float> full_vocab_probs(p_logprob, p_logprob + vocab_size); + r.vocab_log_probs.push_back(std::move(full_vocab_probs)); + } else { + ++r.num_trailing_blanks; } } if (emitted) { diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index ab445f027e..56dc446f13 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -31,7 +31,13 @@ static Ort::Value BuildDecoderInput(int32_t token, OrtAllocator *allocator) { static void DecodeOne(const float *encoder_out, int32_t num_rows, int32_t num_cols, OnlineTransducerNeMoModel *model, - float blank_penalty, OnlineStream *s) { + float blank_penalty, float
temperature_scale, + OnlineStream *s) { + // Defensive: temperature must be > 0. Treat invalid values as "no scaling". + if (temperature_scale <= 0.0f) { + temperature_scale = 1.0f; + } + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -86,6 +92,22 @@ static void DecodeOne(const float *encoder_out, int32_t num_rows, r.timestamps.push_back(t + r.frame_offset); r.num_trailing_blanks = 0; + // Export the per-token log scores + // Copy logits before modifying to avoid issues with tensor data + // Note: p_logit already includes blank_penalty adjustment (applied at line 81) + // so vocab_log_probs will contain adjusted probabilities, not raw model outputs + std::vector<float> logits_copy(p_logit, p_logit + vocab_size); + if (temperature_scale != 1.0f) { + for (int32_t n = 0; n < vocab_size; ++n) { + logits_copy[n] /= temperature_scale; + } + } + LogSoftmax(logits_copy.data(), vocab_size); + r.ys_probs.push_back(logits_copy[y]); + + // Store full vocabulary distribution (includes blank penalty and temperature scaling) + r.vocab_log_probs.push_back(std::move(logits_copy)); + decoder_input = BuildDecoderInput(y, model->Allocator()); // last decoder state becomes the current state for the first chunk @@ -123,7 +145,8 @@ void OnlineTransducerGreedySearchNeMoDecoder::Decode(Ort::Value encoder_out, for (int32_t i = 0; i != batch_size; ++i) { const float *this_p = p + dim1 * dim2 * i; - DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]); + DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, temperature_scale_, + ss[i]); } } diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index 212008fdd1..de71ca5235 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -18,8 +18,11 @@ class OnlineStream; class OnlineTransducerGreedySearchNeMoDecoder { public:
OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, - float blank_penalty) - : model_(model), blank_penalty_(blank_penalty) {} + float blank_penalty, + float temperature_scale) + : model_(model), + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} // @param n number of elements in ss void Decode(Ort::Value encoder_out, OnlineStream **ss, int32_t n) const; @@ -27,6 +30,7 @@ class OnlineTransducerGreedySearchNeMoDecoder { private: OnlineTransducerNeMoModel *model_; // Not owned float blank_penalty_; + float temperature_scale_; }; } // namespace sherpa_onnx