From 7b0e945d7e197e3b258e3ed964a24ff6011a5d5e Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:28:58 +0800 Subject: [PATCH 01/12] feat: add Kaldi-compatible fbank feature extraction Port compute_fbank_kaldi from backup branch as compute_kaldi_fbank with KaldiFbankConfig (sample_rate u32, Povey window, DC removal, natural log, negative high_freq Kaldi convention). Registered in features::mod.rs. Co-Authored-By: Claude Sonnet 4.6 --- src/features/kaldi_fbank.rs | 197 ++++++++++++++++++++++++++++++++++++ src/features/mod.rs | 2 + 2 files changed, 199 insertions(+) create mode 100644 src/features/kaldi_fbank.rs diff --git a/src/features/kaldi_fbank.rs b/src/features/kaldi_fbank.rs new file mode 100644 index 0000000..9a48169 --- /dev/null +++ b/src/features/kaldi_fbank.rs @@ -0,0 +1,197 @@ +//! Kaldi-compatible FBank feature extraction. +//! +//! Implements the same feature pipeline used by sherpa-onnx / kaldi-native-fbank: +//! Povey window (Hamming^0.85), DC offset removal, preemphasis, power spectrum, +//! triangular mel filterbank, and natural log. The `high_freq` field follows the +//! Kaldi sign convention: negative values are treated as `nyquist + high_freq` +//! (e.g. `-400` → 7600 Hz at 16 kHz). + +use ndarray::Array2; +use rustfft::{num_complex::Complex, FftPlanner}; + +/// Kaldi-compatible FBank configuration matching sherpa-onnx / kaldi-native-fbank. +#[derive(Debug, Clone)] +pub struct KaldiFbankConfig { + pub num_bins: usize, + pub fft_size: usize, + pub window_size: usize, + pub hop_size: usize, + pub sample_rate: u32, + pub low_freq: f32, + /// Negative means nyquist + high_freq (Kaldi convention). -400 → 7600 Hz at 16kHz. + pub high_freq: f32, + pub preemph_coeff: f32, + pub snip_edges: bool, + pub remove_dc_offset: bool, +} + +impl Default for KaldiFbankConfig { + fn default() -> Self { + Self { + num_bins: 80, + fft_size: 512, + window_size: 400, + hop_size: 160, + sample_rate: 16000, + low_freq: 20.0, + high_freq: -400.0, + preemph_coeff: 0.97, + snip_edges: false, + remove_dc_offset: true, + } + } +} + +/// Compute Kaldi-compatible FBank features from audio samples. +/// +/// Returns an array of shape `[num_frames, num_bins]` (time-major). +pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 { + let window_size = config.window_size; + let hop_size = config.hop_size; + let fft_size = config.fft_size; + let half_fft = fft_size / 2 + 1; + + if samples.is_empty() { + return Array2::zeros((0, config.num_bins)); + } + + let num_frames = if config.snip_edges { + if samples.len() < window_size { + return Array2::zeros((0, config.num_bins)); + } + (samples.len() - window_size) / hop_size + 1 + } else { + (samples.len() + hop_size / 2) / hop_size + }; + + if num_frames == 0 { + return Array2::zeros((0, config.num_bins)); + } + + let filterbank = mel_filterbank(config); + + // Povey window: hamming^0.85 + let window: Vec = (0..window_size) + .map(|i| { + let hamming = 0.54 + - 0.46 + * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); + hamming.powf(0.85) + }) + .collect(); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(fft_size); + + let mut features = Vec::with_capacity(num_frames * config.num_bins); + + for frame_idx in 0..num_frames { + let center = if config.snip_edges { + frame_idx * hop_size + window_size / 2 + } else { + frame_idx * hop_size + }; + let start = center as isize - (window_size as isize / 2); + + let mut frame = vec![0.0f32; window_size]; + for i in 0..window_size { + let idx = start + i as isize; + if idx >= 0 && (idx as usize) < samples.len() { + frame[i] = samples[idx as usize]; + } + } + + if config.remove_dc_offset { + let mean: f32 = frame.iter().sum::() / window_size as f32; + for s in frame.iter_mut() { + *s -= mean; + } + } + + if config.preemph_coeff > 0.0 { + for i in (1..window_size).rev() { + frame[i] -= config.preemph_coeff * frame[i - 1]; + } + frame[0] *= 1.0 - config.preemph_coeff; + } + + let mut buffer: Vec> = frame + .iter() + .zip(window.iter()) + .map(|(&s, &w)| Complex::new(s * w, 0.0)) + .collect(); + buffer.resize(fft_size, Complex::new(0.0, 0.0)); + fft.process(&mut buffer); + + let power: Vec = buffer[..half_fft].iter().map(|c| c.norm_sqr()).collect(); + + for filter in &filterbank { + let mut sum = 0.0f32; + for (i, &w) in filter.iter().enumerate() { + sum += w * power[i]; + } + features.push(if sum > f32::EPSILON { + sum.ln() + } else { + f32::EPSILON.ln() + }); + } + } + + Array2::from_shape_vec((num_frames, config.num_bins), features).unwrap() +} + +/// Build a triangular mel filterbank matrix. +/// +/// Returns a `Vec` of `num_bins` filters, each of length `fft_size / 2 + 1`. +fn mel_filterbank(config: &KaldiFbankConfig) -> Vec> { + let num_bins = config.num_bins; + let fft_size = config.fft_size; + let sample_rate = config.sample_rate as f32; + let nyquist = sample_rate / 2.0; + let low_freq = config.low_freq; + let high_freq = if config.high_freq <= 0.0 { + nyquist + config.high_freq + } else { + config.high_freq + }; + + let hz_to_mel = |hz: f32| 1127.0 * (1.0 + hz / 700.0).ln(); + let mel_to_hz = |mel: f32| 700.0 * ((mel / 1127.0).exp() - 1.0); + + let low_mel = hz_to_mel(low_freq); + let high_mel = hz_to_mel(high_freq); + + let num_points = num_bins + 2; + let mel_points: Vec = (0..num_points) + .map(|i| low_mel + (high_mel - low_mel) * i as f32 / (num_points - 1) as f32) + .collect(); + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + let fft_bins: Vec = hz_points + .iter() + .map(|&hz| ((hz * fft_size as f32) / sample_rate).floor() as usize) + .collect(); + + let half_fft = fft_size / 2 + 1; + let mut filterbank = vec![vec![0.0f32; half_fft]; num_bins]; + for (i, filter) in filterbank.iter_mut().enumerate() { + let left = fft_bins[i]; + let center = fft_bins[i + 1]; + let right = fft_bins[i + 2]; + if center > left { + for j in left..center { + if j < half_fft { + filter[j] = (j - left) as f32 / (center - left) as f32; + } + } + } + if right > center { + for j in center..right { + if j < half_fft { + filter[j] = (right - j) as f32 / (right - center) as f32; + } + } + } + } + filterbank +} diff --git a/src/features/mod.rs b/src/features/mod.rs index bc9d16b..c7a894c 100644 --- a/src/features/mod.rs +++ b/src/features/mod.rs @@ -1,7 +1,9 @@ mod cmvn; +pub mod kaldi_fbank; mod lfr; mod mel; pub use cmvn::apply_cmvn; +pub use kaldi_fbank::{compute_kaldi_fbank, KaldiFbankConfig}; pub use lfr::apply_lfr; pub use mel::{compute_mel, MelConfig, WindowType}; From 1d0f5f491df23cec9cefc0d5d4b1d176056c5b20 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:30:33 +0800 Subject: [PATCH 02/12] feat: add BBPE symbol table for Icefall/sherpa-onnx models Co-Authored-By: Claude Sonnet 4.6 --- src/decode/bbpe.rs | 175 +++++++++++++++++++++++++++++++++++++++++++++ src/decode/mod.rs | 2 + 2 files changed, 177 insertions(+) create mode 100644 src/decode/bbpe.rs diff --git a/src/decode/bbpe.rs b/src/decode/bbpe.rs new file mode 100644 index 0000000..8d7af7a --- /dev/null +++ b/src/decode/bbpe.rs @@ -0,0 +1,175 @@ +//! BBPE (Byte-level BPE) symbol table for Icefall/sherpa-onnx Zipformer models. +//! +//! Handles byte-to-unicode mapping and auto-detects encoding mode based on +//! whether a `bbpe.model` file is present alongside the tokens file. + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +/// Whether tokens use BBPE byte encoding or standard BPE (literal UTF-8). +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TokenEncoding { + Bbpe, + Bpe, +} + +/// Symbol table for Icefall/sherpa-onnx Zipformer models with BBPE support. +pub struct BbpeSymbolTable { + id_to_sym: HashMap, + encoding: TokenEncoding, +} + +impl BbpeSymbolTable { + /// Load a symbol table, auto-detecting encoding mode. + /// + /// If a `bbpe.model` file exists in the same directory as `path`, + /// BBPE encoding is used; otherwise standard BPE is assumed. + pub fn load(path: &Path) -> Result { + Self::load_autodetect(path) + } + + /// Load with explicit auto-detection of encoding based on sibling files. + pub fn load_autodetect(path: &Path) -> Result { + let encoding = if let Some(dir) = path.parent() { + if dir.join("bbpe.model").exists() { + log::info!("Detected BBPE encoding (bbpe.model found)"); + TokenEncoding::Bbpe + } else { + log::info!("Detected standard BPE encoding (no bbpe.model)"); + TokenEncoding::Bpe + } + } else { + TokenEncoding::Bbpe + }; + Self::load_with_encoding(path, encoding) + } + + /// Load with an explicitly specified encoding. + pub fn load_with_encoding(path: &Path, encoding: TokenEncoding) -> Result { + let contents = fs::read_to_string(path)?; + let mut id_to_sym = HashMap::new(); + for line in contents.lines() { + let line = line.trim_end(); + if line.is_empty() { + continue; + } + let parts: Vec<&str> = line.rsplitn(2, |c: char| c.is_whitespace()).collect(); + if parts.len() != 2 { + continue; + } + if let Ok(id) = parts[0].parse::() { + id_to_sym.insert(id, parts[1].to_string()); + } + } + Ok(Self { id_to_sym, encoding }) + } + + /// Look up a symbol by token ID. + pub fn get(&self, id: i32) -> Option<&str> { + self.id_to_sym.get(&id).map(|s| s.as_str()) + } + + /// Decode a sequence of token IDs to a UTF-8 string. + pub fn decode(&self, token_ids: &[i32]) -> String { + match self.encoding { + TokenEncoding::Bbpe => self.decode_bbpe(token_ids), + TokenEncoding::Bpe => self.decode_bpe(token_ids), + } + } + + fn decode_bbpe(&self, token_ids: &[i32]) -> String { + let mut raw_bytes = Vec::new(); + for &id in token_ids { + let Some(sym) = self.get(id) else { continue }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + for c in sym.chars() { + if c == '\u{2581}' { + raw_bytes.push(b' '); + } else if let Some(byte_val) = bbpe_char_to_byte(c) { + raw_bytes.push(byte_val); + } + } + } + let text = String::from_utf8_lossy(&raw_bytes); + normalize_text(text.trim()) + } + + fn decode_bpe(&self, token_ids: &[i32]) -> String { + let mut text = String::new(); + for &id in token_ids { + let Some(sym) = self.get(id) else { continue }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + text.push_str(&sym.replace('\u{2581}', " ")); + } + normalize_text(text.trim()) + } +} + +fn is_cjk(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' | + '\u{3400}'..='\u{4DBF}' | + '\u{F900}'..='\u{FAFF}' | + '\u{2E80}'..='\u{2EFF}' | + '\u{3000}'..='\u{303F}' | + '\u{FF00}'..='\u{FFEF}' + ) +} + +fn normalize_text(text: &str) -> String { + let text = text.to_lowercase(); + let chars: Vec = text.chars().collect(); + let mut result = String::with_capacity(text.len()); + for i in 0..chars.len() { + let c = chars[i]; + if c == ' ' { + let prev_cjk = i > 0 && is_cjk(chars[i - 1]); + let next_cjk = i + 1 < chars.len() && is_cjk(chars[i + 1]); + if prev_cjk && next_cjk { + continue; + } + } + result.push(c); + } + result +} + +/// BBPE codepoint table: maps byte value (index) to Unicode codepoint. +const BBPE_CODEPOINTS: [u32; 256] = [ + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, + 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, + 286, 287, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 288, 289, 290, 291, 292, + 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, 309, + 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, + 327, 328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, + 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, + 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, + 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, + 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, + 419, 420, 421, 422, +]; + +/// Convert a BBPE-encoded Unicode character back to its original byte value. +pub fn bbpe_char_to_byte(c: char) -> Option { + let cp = c as u32; + if (32..=126).contains(&cp) { + return Some(cp as u8); + } + for (byte_val, &mapped_cp) in BBPE_CODEPOINTS.iter().enumerate() { + if mapped_cp == cp { + return Some(byte_val as u8); + } + } + None +} diff --git a/src/decode/mod.rs b/src/decode/mod.rs index f33bf39..9ff8655 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -1,7 +1,9 @@ +pub mod bbpe; mod ctc; mod sentencepiece; pub mod tokens; pub use ctc::{ctc_greedy_decode, CtcDecoderResult}; pub use sentencepiece::sentencepiece_to_text; +pub use bbpe::BbpeSymbolTable; pub use tokens::{load_vocab, SymbolTable}; From 8081c169d32dc9a98a7aaa2d07036deb60ae8ee5 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:34:39 +0800 Subject: [PATCH 03/12] feat: add Paraformer ONNX engine Non-autoregressive ASR model with custom fbank (Hamming/dB scale), LFR stacking, mean-only CMVN, and @@-subword symbol table decoding. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/onnx/mod.rs | 1 + src/onnx/paraformer/mod.rs | 635 +++++++++++++++++++++++++++++++++++++ 2 files changed, 636 insertions(+) create mode 100644 src/onnx/paraformer/mod.rs diff --git a/src/onnx/mod.rs b/src/onnx/mod.rs index 98ee2fc..85fef31 100644 --- a/src/onnx/mod.rs +++ b/src/onnx/mod.rs @@ -22,5 +22,6 @@ pub enum Quantization { pub mod canary; pub mod gigaam; pub mod moonshine; +pub mod paraformer; pub mod parakeet; pub mod sense_voice; diff --git a/src/onnx/paraformer/mod.rs b/src/onnx/paraformer/mod.rs new file mode 100644 index 0000000..bb396f3 --- /dev/null +++ b/src/onnx/paraformer/mod.rs @@ -0,0 +1,635 @@ +//! Paraformer ONNX speech recognition engine. +//! +//! Non-autoregressive end-to-end ASR model from FunASR/ModelScope. +//! Uses its own fbank feature extraction (Hamming window, dB scale), +//! LFR frame stacking, and a custom symbol table with `@@` subword joining. + +use ndarray::{Array1, Array2}; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use rustfft::{num_complex::Complex, FftPlanner}; +use std::collections::HashMap; +use std::f32::consts::PI; +use std::path::Path; + +use super::session; +use super::Quantization; +use crate::features::apply_lfr; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Paraformer", + engine_id: "paraformer", + sample_rate: 16000, + languages: &["zh", "en", "yue"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Paraformer. +#[derive(Debug, Clone, Default)] +pub struct ParaformerParams { + /// Language hint (currently unused, Paraformer handles zh/en/yue automatically). + pub language: Option, +} + +// ---- Metadata ---- + +struct ParaformerMetadata { + lfr_window_size: usize, + lfr_window_shift: usize, + blank_id: i32, + sos_id: i32, + eos_id: i32, +} + +// ---- Symbol Table ---- + +/// Paraformer-specific symbol table with `@@` subword joining and CJK-aware spacing. +struct ParaformerSymbolTable { + id_to_sym: HashMap, +} + +impl ParaformerSymbolTable { + /// Load a tokens.txt file where each line is `symbol id`. + fn load(path: &Path) -> Result { + let content = std::fs::read_to_string(path)?; + let mut id_to_sym = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Split on last whitespace: "symbol id" + if let Some(pos) = line.rfind(|c: char| c.is_ascii_whitespace()) { + let sym = &line[..pos]; + let id_str = line[pos..].trim(); + if let Ok(id) = id_str.parse::() { + id_to_sym.insert(id, sym.to_string()); + } + } + } + + log::info!( + "Loaded Paraformer symbol table with {} tokens from {:?}", + id_to_sym.len(), + path + ); + Ok(Self { id_to_sym }) + } + + fn get(&self, id: i32) -> Option<&str> { + self.id_to_sym.get(&id).map(|s| s.as_str()) + } + + /// Decode a sequence of token IDs into text. + fn decode(&self, token_ids: &[i32]) -> String { + let mut text = String::new(); + let mut prev_join_to_next = false; + + for &id in token_ids { + let Some(sym) = self.get(id) else { + continue; + }; + if is_special_symbol(sym) { + continue; + } + + let joins_next = sym.ends_with("@@"); + let clean = sym.trim_end_matches("@@"); + if clean.is_empty() { + prev_join_to_next = joins_next; + continue; + } + + if clean.starts_with('\u{2581}') { + let piece = clean.trim_start_matches('\u{2581}'); + if !piece.is_empty() { + if !text.is_empty() && !text.ends_with(' ') { + text.push(' '); + } + text.push_str(piece); + } + prev_join_to_next = joins_next; + continue; + } + + if !text.is_empty() && !prev_join_to_next { + let prev_char = text.chars().last(); + let curr_is_ascii_word = is_ascii_word_piece(clean); + let prev_is_ascii_word = prev_char.map(is_ascii_word_char).unwrap_or(false); + let prev_is_cjk = prev_char.map(is_cjk).unwrap_or(false); + if curr_is_ascii_word + && (prev_is_ascii_word || prev_is_cjk) + && !text.ends_with(' ') + { + text.push(' '); + } + } + + text.push_str(clean); + prev_join_to_next = joins_next; + } + + text.trim().to_string() + } +} + +fn is_special_symbol(sym: &str) -> bool { + sym == "" + || sym == "" + || sym == "" + || sym == "" + || sym == "" + || (sym.starts_with('<') && sym.ends_with('>')) +} + +fn is_ascii_word_piece(s: &str) -> bool { + !s.is_empty() && s.chars().all(is_ascii_word_char) +} + +fn is_ascii_word_char(c: char) -> bool { + c.is_ascii_alphanumeric() +} + +fn is_cjk(c: char) -> bool { + let code = c as u32; + (0x4E00..=0x9FFF).contains(&code) + || (0x3400..=0x4DBF).contains(&code) + || (0x20000..=0x2A6DF).contains(&code) + || (0x2A700..=0x2B73F).contains(&code) + || (0x2B740..=0x2B81F).contains(&code) + || (0x2B820..=0x2CEAF).contains(&code) +} + +// ---- Feature Extraction ---- + +/// Compute Paraformer-style fbank features (Hamming window, dB scale). +/// +/// This is NOT the same as Kaldi fbank or the upstream `compute_mel()`. +/// Key differences: +/// - Standard Hamming window (no Povey modification) +/// - No preemphasis, no DC offset removal +/// - dB scale output: `10.0 * log10(energy)` with -80 dB floor +/// - snip_edges=true +/// +/// Returns [num_frames, num_bins] (80 bins by default). +fn compute_paraformer_fbank(samples: &[f32]) -> Array2 { + const NUM_BINS: usize = 80; + const FFT_SIZE: usize = 512; + const WINDOW_SIZE: usize = 400; + const HOP_SIZE: usize = 160; + const SAMPLE_RATE: f32 = 16000.0; + const LOW_FREQ: f32 = 0.0; + const HIGH_FREQ: f32 = 8000.0; + + if samples.len() < WINDOW_SIZE { + return Array2::zeros((0, NUM_BINS)); + } + + let num_frames = (samples.len() - WINDOW_SIZE) / HOP_SIZE + 1; + let num_fft_bins = FFT_SIZE / 2 + 1; + + // Standard Hamming window + let window: Vec = (0..WINDOW_SIZE) + .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f32 / (WINDOW_SIZE as f32 - 1.0)).cos()) + .collect(); + + // Mel filterbank [NUM_BINS, num_fft_bins] + let mel_banks = mel_filterbank(NUM_BINS, FFT_SIZE, SAMPLE_RATE, LOW_FREQ, HIGH_FREQ); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(FFT_SIZE); + + let mut features = Array2::zeros((num_frames, NUM_BINS)); + + for i in 0..num_frames { + let start = i * HOP_SIZE; + + // Extract and window the frame + let mut fft_input: Vec> = Vec::with_capacity(FFT_SIZE); + for j in 0..WINDOW_SIZE { + let sample = if start + j < samples.len() { + samples[start + j] + } else { + 0.0 + }; + fft_input.push(Complex::new(sample * window[j], 0.0)); + } + // Zero-pad to FFT_SIZE + fft_input.resize(FFT_SIZE, Complex::new(0.0, 0.0)); + + fft.process(&mut fft_input); + + // Power spectrum + let power_spectrum: Vec = fft_input[..num_fft_bins] + .iter() + .map(|c| c.norm_sqr()) + .collect(); + + // Apply mel filterbank and convert to dB scale + for m in 0..NUM_BINS { + let energy: f32 = mel_banks + .row(m) + .iter() + .zip(power_spectrum.iter()) + .map(|(&w, &p)| w * p) + .sum(); + + // dB scale: 10 * log10(energy), with -80 dB floor + let db = if energy < 1.0e-10 { + -80.0 + } else { + 10.0 * energy.log10() + }; + features[[i, m]] = db.max(-80.0); + } + } + + features +} + +/// Compute mel filterbank matrix of shape [num_mels, num_fft_bins]. +fn mel_filterbank( + num_mels: usize, + fft_size: usize, + sample_rate: f32, + low_freq: f32, + high_freq: f32, +) -> Array2 { + let num_fft_bins = fft_size / 2 + 1; + + let mel_low = hz_to_mel(low_freq); + let mel_high = hz_to_mel(high_freq); + + let num_points = num_mels + 2; + let mel_points: Vec = (0..num_points) + .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (num_points - 1) as f32) + .collect(); + + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + + let bin_points: Vec = hz_points + .iter() + .map(|&f| f * fft_size as f32 / sample_rate) + .collect(); + + let mut banks = Array2::zeros((num_mels, num_fft_bins)); + + for m in 0..num_mels { + let left = bin_points[m]; + let center = bin_points[m + 1]; + let right = bin_points[m + 2]; + + for k in 0..num_fft_bins { + let kf = k as f32; + if kf > left && kf < center { + banks[[m, k]] = (kf - left) / (center - left); + } else if kf >= center && kf < right { + banks[[m, k]] = (right - kf) / (right - center); + } + } + } + + banks +} + +fn hz_to_mel(hz: f32) -> f32 { + 1127.0 * (1.0 + hz / 700.0).ln() +} + +fn mel_to_hz(mel: f32) -> f32 { + 700.0 * ((mel / 1127.0).exp() - 1.0) +} + +// ---- CMVN ---- + +/// Apply mean-only CMVN normalization (subtract mean, no stddev scaling). +fn apply_mean_cmvn(features: &mut Array2, mean: &Array1) { + let ncols = features.ncols(); + for mut row in features.rows_mut() { + for j in 0..ncols { + row[j] -= mean[j]; + } + } +} + +/// Load CMVN mean from an `am.mvn` file (Kaldi-style format). +/// +/// Parses the `` section and extracts the mean vector. +/// Returns `None` if the file doesn't contain the expected format. +fn load_cmvn_mean(path: &Path, target_dim: usize) -> Result>, std::io::Error> { + let content = std::fs::read_to_string(path)?; + let Some(start_idx) = content.find("") else { + return Ok(None); + }; + let rest = &content[start_idx..]; + let Some(lb_rel) = rest.find('[') else { + return Ok(None); + }; + let Some(rb_rel) = rest.find(']') else { + return Ok(None); + }; + if rb_rel <= lb_rel { + return Ok(None); + } + let body = &rest[lb_rel + 1..rb_rel]; + let mut values = Vec::new(); + for tok in body.split_whitespace() { + if let Ok(v) = tok.parse::() { + values.push(v); + } + } + if values.len() < target_dim { + return Ok(None); + } + Ok(Some(Array1::from_vec( + values.into_iter().take(target_dim).collect(), + ))) +} + +// ---- Model ---- + +pub struct ParaformerModel { + session: Session, + symbol_table: ParaformerSymbolTable, + metadata: ParaformerMetadata, + cmvn_mean: Option>, + speech_input_name: String, + speech_lengths_input_name: String, + #[allow(dead_code)] + logits_output_name: String, + #[allow(dead_code)] + token_num_output_name: Option, +} + +impl ParaformerModel { + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let model_path = session::resolve_model_path(model_dir, "model", quantization); + let tokens_path = model_dir.join("tokens.txt"); + let cmvn_path = model_dir.join("am.mvn"); + + if !model_path.exists() { + return Err(TranscribeError::ModelNotFound(model_path)); + } + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading Paraformer model from {:?}...", model_path); + let session = session::create_session(&model_path)?; + + // Read metadata from ONNX model + let lfr_window_size = + session::read_metadata_i32(&session, "lfr_window_size", Some(7))?.unwrap() as usize; + let lfr_window_shift = + session::read_metadata_i32(&session, "lfr_window_shift", Some(6))?.unwrap() as usize; + let blank_id = session::read_metadata_i32(&session, "blank_id", Some(0))?.unwrap(); + let sos_id = session::read_metadata_i32(&session, "sos_id", Some(1))?.unwrap(); + let eos_id = session::read_metadata_i32(&session, "eos_id", Some(2))?.unwrap(); + + let metadata = ParaformerMetadata { + lfr_window_size, + lfr_window_shift, + blank_id, + sos_id, + eos_id, + }; + + log::info!( + "Paraformer metadata: lfr_window={}x{}, blank={}, sos={}, eos={}", + lfr_window_size, + lfr_window_shift, + blank_id, + sos_id, + eos_id, + ); + + // Detect I/O names from session + let inputs: Vec = session.inputs().iter().map(|i| i.name().to_string()).collect(); + let outputs: Vec = session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let speech_input_name = inputs + .iter() + .find(|n| n.contains("speech")) + .cloned() + .unwrap_or_else(|| inputs.first().cloned().unwrap_or_else(|| "speech".to_string())); + + let speech_lengths_input_name = inputs + .iter() + .find(|n| n.contains("speech_lengths") || n.contains("length")) + .cloned() + .unwrap_or_else(|| { + inputs + .get(1) + .cloned() + .unwrap_or_else(|| "speech_lengths".to_string()) + }); + + let logits_output_name = outputs + .iter() + .find(|n| n.contains("logits")) + .cloned() + .unwrap_or_else(|| outputs.first().cloned().unwrap_or_else(|| "logits".to_string())); + + let token_num_output_name = outputs + .iter() + .find(|n| n.contains("token_num")) + .cloned(); + + log::debug!( + "I/O names: speech={}, lengths={}, logits={}, token_num={:?}", + speech_input_name, + speech_lengths_input_name, + logits_output_name, + token_num_output_name, + ); + + // Load symbol table + let symbol_table = ParaformerSymbolTable::load(&tokens_path)?; + + // Load CMVN mean (LFR dim = 80 * lfr_window_size) + let lfr_dim = 80 * lfr_window_size; + let cmvn_mean = if cmvn_path.exists() { + match load_cmvn_mean(&cmvn_path, lfr_dim) { + Ok(mean) => { + if mean.is_some() { + log::info!("Loaded CMVN mean from {:?} (dim={})", cmvn_path, lfr_dim); + } + mean + } + Err(e) => { + log::warn!("Failed to load CMVN from {:?}: {}", cmvn_path, e); + None + } + } + } else { + log::debug!("No am.mvn file found at {:?}, skipping CMVN", cmvn_path); + None + }; + + Ok(Self { + session, + symbol_table, + metadata, + cmvn_mean, + speech_input_name, + speech_lengths_input_name, + logits_output_name, + token_num_output_name, + }) + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ParaformerParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Paraformer fbank features [frames, 80] + let features = compute_paraformer_fbank(samples); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!( + "Paraformer fbank: [{}, {}]", + features.nrows(), + features.ncols() + ); + + // 2. Apply LFR + let features = apply_lfr( + &features, + self.metadata.lfr_window_size, + self.metadata.lfr_window_shift, + ); + + log::debug!("After LFR: [{}, {}]", features.nrows(), features.ncols()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + // 3. Apply mean-only CMVN + let mut features = features; + if let Some(ref mean) = self.cmvn_mean { + apply_mean_cmvn(&mut features, mean); + } + + // 4. Forward pass + let logits = self.forward(&features)?; + + log::debug!("Logits shape: {:?}", logits.shape()); + + // 5. Decode (non-autoregressive argmax, NOT CTC) + let token_ids = self.decode_logits(&logits); + + // 6. Convert to text + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run ONNX forward pass. Returns logits [1, T, vocab_size]. + fn forward( + &mut self, + features: &Array2, + ) -> Result, TranscribeError> { + let num_frames = features.nrows(); + + // Shape: [1, T, D] + let feat_3d = features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; + let speech_lengths = ndarray::arr1(&[num_frames as i32]); + + let feat_dyn = feat_3d.into_dyn(); + let lengths_dyn = speech_lengths.into_dyn(); + + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lengths = TensorRef::from_array_view(lengths_dyn.view())?; + + let ort_inputs = inputs![ + self.speech_input_name.as_str() => t_feat, + self.speech_lengths_input_name.as_str() => t_lengths, + ]; + + let outputs = self.session.run(ort_inputs)?; + + let logits = outputs[0].try_extract_array::()?; + let logits_owned = logits.to_owned().into_dimensionality::()?; + + Ok(logits_owned) + } + + /// Decode logits using argmax with blank/sos/eos filtering. + /// + /// Paraformer is non-autoregressive, so this is a simple argmax over the + /// vocab dimension, NOT CTC greedy decode. + fn decode_logits(&self, logits: &ndarray::Array3) -> Vec { + let blank_id = self.metadata.blank_id; + let sos_id = self.metadata.sos_id; + let eos_id = self.metadata.eos_id; + + let seq_len = logits.shape()[1]; + let mut token_ids = Vec::new(); + + for t in 0..seq_len { + // Argmax over vocab dimension + let mut best_id = 0i32; + let mut best_val = f32::NEG_INFINITY; + for (v, &val) in logits.slice(ndarray::s![0, t, ..]).iter().enumerate() { + if val > best_val { + best_val = val; + best_id = v as i32; + } + } + + // Skip blank, sos, eos + if best_id == blank_id || best_id == sos_id || best_id == eos_id { + continue; + } + + token_ids.push(best_id); + } + + token_ids + } +} + +impl SpeechModel for ParaformerModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} From 62c80e52b513000bc9a598424ffb83d70b542728 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:37:44 +0800 Subject: [PATCH 04/12] feat: add Zipformer CTC ONNX engine Implements ZipformerCtcModel with SpeechModel trait, Kaldi fbank feature extraction, and CTC greedy decode using BbpeSymbolTable. Supports both standard model.onnx naming and sherpa-onnx directory-scan fallback. Rejects streaming models that contain cached_* inputs at load time. Co-Authored-By: Claude Sonnet 4.6 --- src/onnx/mod.rs | 1 + src/onnx/zipformer_ctc/mod.rs | 328 ++++++++++++++++++++++++++++++++++ 2 files changed, 329 insertions(+) create mode 100644 src/onnx/zipformer_ctc/mod.rs diff --git a/src/onnx/mod.rs b/src/onnx/mod.rs index 85fef31..d661273 100644 --- a/src/onnx/mod.rs +++ b/src/onnx/mod.rs @@ -25,3 +25,4 @@ pub mod moonshine; pub mod paraformer; pub mod parakeet; pub mod sense_voice; +pub mod zipformer_ctc; diff --git a/src/onnx/zipformer_ctc/mod.rs b/src/onnx/zipformer_ctc/mod.rs new file mode 100644 index 0000000..ac2b7d5 --- /dev/null +++ b/src/onnx/zipformer_ctc/mod.rs @@ -0,0 +1,328 @@ +//! Zipformer CTC ONNX speech recognition engine. +//! +//! Supports sherpa-onnx Zipformer CTC models (e.g. from Icefall) for +//! Chinese and English transcription. Streaming models (with `cached_*` +//! inputs) are rejected at load time. + +use ndarray::Array2; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use std::path::{Path, PathBuf}; + +use super::session; +use super::Quantization; +use crate::decode::{ctc_greedy_decode, BbpeSymbolTable}; +use crate::features::{compute_kaldi_fbank, KaldiFbankConfig}; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Zipformer CTC", + engine_id: "zipformer_ctc", + sample_rate: 16000, + languages: &["zh", "en"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Zipformer CTC. +#[derive(Debug, Clone, Default)] +pub struct ZipformerCtcParams { + /// Language hint (currently unused; the model handles zh/en automatically). + pub language: Option, +} + +// ---- Model ---- + +pub struct ZipformerCtcModel { + session: Session, + symbol_table: BbpeSymbolTable, + blank_id: i64, + x_input_name: String, + x_lens_input_name: String, + #[allow(dead_code)] + log_probs_output_name: String, + /// Index of the output that contains output lengths, if present. + log_probs_len_output_idx: Option, +} + +impl ZipformerCtcModel { + /// Load a Zipformer CTC model from `model_dir`. + /// + /// Attempts `session::resolve_model_path(dir, "model", quantization)` first, + /// then falls back to scanning the directory for any `.onnx` file (for + /// sherpa-onnx models with names like `model-epoch-34-avg-19.int8.onnx`). + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let model_path = Self::find_model_file(model_dir, quantization)?; + let tokens_path = model_dir.join("tokens.txt"); + + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading Zipformer CTC model from {:?}...", model_path); + let session = session::create_session(&model_path)?; + + // Reject streaming models — they have cached_* inputs + let input_names: Vec = session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + if input_names.iter().any(|n| n.starts_with("cached_")) { + return Err(TranscribeError::Config(format!( + "Streaming Zipformer models are not supported (found cached_* inputs in {:?}). \ + Use a non-streaming (offline) model.", + model_path + ))); + } + + log::debug!("Model inputs: {:?}", input_names); + + let output_names: Vec = session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + log::debug!("Model outputs: {:?}", output_names); + + // Detect input names + let x_input_name = input_names + .iter() + .find(|n| n.as_str() == "x" || n.contains("feat") || n.contains("input")) + .cloned() + .unwrap_or_else(|| { + input_names + .first() + .cloned() + .unwrap_or_else(|| "x".to_string()) + }); + + let x_lens_input_name = input_names + .iter() + .find(|n| n.contains("len") || n.contains("length")) + .cloned() + .unwrap_or_else(|| { + input_names + .get(1) + .cloned() + .unwrap_or_else(|| "x_lens".to_string()) + }); + + // Detect output names + let log_probs_output_name = output_names + .iter() + .find(|n| n.contains("log_prob") || n.contains("logit") || n.contains("prob")) + .cloned() + .unwrap_or_else(|| { + output_names + .first() + .cloned() + .unwrap_or_else(|| "log_probs".to_string()) + }); + + let log_probs_len_output_idx = output_names + .iter() + .position(|n| n.contains("len") || n.contains("length")); + + log::debug!( + "I/O mapping: x={}, x_lens={}, log_probs={}, log_probs_len_idx={:?}", + x_input_name, + x_lens_input_name, + log_probs_output_name, + log_probs_len_output_idx, + ); + + // Load BBPE symbol table (auto-detects BBPE vs BPE encoding) + let symbol_table = BbpeSymbolTable::load(&tokens_path)?; + + // blank_id is always 0 for sherpa-onnx Zipformer CTC models + let blank_id = 0i64; + + Ok(Self { + session, + symbol_table, + blank_id, + x_input_name, + x_lens_input_name, + log_probs_output_name, + log_probs_len_output_idx, + }) + } + + /// Find the ONNX model file in `model_dir`. + /// + /// Priority: + /// 1. `session::resolve_model_path(dir, "model", quantization)` — standard naming + /// 2. Scan directory for `*.int8.onnx` (when Int8 requested) or `*.onnx` + fn find_model_file(model_dir: &Path, quantization: &Quantization) -> Result { + // Try standard path first + let standard_path = session::resolve_model_path(model_dir, "model", quantization); + if standard_path.exists() { + return Ok(standard_path); + } + + // Fallback: scan directory for onnx files + let prefer_int8 = *quantization == Quantization::Int8; + + let read_dir = std::fs::read_dir(model_dir).map_err(|e| { + TranscribeError::Io(std::io::Error::new( + e.kind(), + format!("cannot read model directory {:?}: {}", model_dir, e), + )) + })?; + + let mut int8_candidates: Vec = Vec::new(); + let mut fp32_candidates: Vec = Vec::new(); + + for entry in read_dir.flatten() { + let path = entry.path(); + if path.extension().and_then(|e| e.to_str()) == Some("onnx") { + let name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or(""); + if name.contains("int8") || name.contains("int4") { + int8_candidates.push(path); + } else { + fp32_candidates.push(path); + } + } + } + + // Sort for determinism + int8_candidates.sort(); + fp32_candidates.sort(); + + if prefer_int8 { + if let Some(p) = int8_candidates.into_iter().next() { + log::info!("Found int8 model by directory scan: {:?}", p); + return Ok(p); + } + } + + if let Some(p) = fp32_candidates.into_iter().next() { + log::info!("Found model by directory scan: {:?}", p); + return Ok(p); + } + + // Last resort: return the int8 candidate even if fp32 preferred + Err(TranscribeError::ModelNotFound( + model_dir.join("model.onnx"), + )) + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ZipformerCtcParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Kaldi FBank features [frames, 80] + let features = compute_kaldi_fbank(samples, &KaldiFbankConfig::default()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!( + "Kaldi fbank: [{}, {}]", + features.nrows(), + features.ncols() + ); + + // 2. Run ONNX forward pass → log_probs [1, T', vocab] + let (log_probs, output_len) = self.forward(&features)?; + + log::debug!( + "log_probs shape: {:?}, output_len={}", + log_probs.shape(), + output_len + ); + + // 3. CTC greedy decode (expects [batch, time, vocab]) + let logits_lengths = vec![output_len]; + let results = ctc_greedy_decode(&log_probs.view(), &logits_lengths, self.blank_id); + + // 4. Convert token IDs (i64 → i32) and decode to text + let token_ids: Vec = results[0].tokens.iter().map(|&t| t as i32).collect(); + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run ONNX forward pass. + /// + /// Returns `(log_probs [1, T, vocab], output_len)` where `output_len` is + /// the valid frame count for batch element 0. + fn forward(&mut self, features: &Array2) -> Result<(ndarray::Array3, i64), TranscribeError> { + let num_frames = features.nrows() as i64; + + // Shape: [1, T, 80] + let feat_3d = features + .to_owned() + .into_shape_with_order((1, features.nrows(), features.ncols()))?; + let x_lens = ndarray::arr1(&[num_frames]); + + let feat_dyn = feat_3d.into_dyn(); + let lens_dyn = x_lens.into_dyn(); + + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lens = TensorRef::from_array_view(lens_dyn.view())?; + + let ort_inputs = inputs![ + self.x_input_name.as_str() => t_feat, + self.x_lens_input_name.as_str() => t_lens, + ]; + + let outputs = self.session.run(ort_inputs)?; + + // Extract log_probs — always the first output, shape [1, T', vocab] + let log_probs = outputs[0].try_extract_array::()?; + let log_probs = log_probs + .to_owned() + .into_dimensionality::()?; + + // Determine output length: use the length output if available, else T' + let output_len = if let Some(len_idx) = self.log_probs_len_output_idx { + if len_idx < outputs.len() { + let lens = outputs[len_idx].try_extract_array::()?; + lens.first().copied().unwrap_or(log_probs.shape()[1] as i64) + } else { + log_probs.shape()[1] as i64 + } + } else { + log_probs.shape()[1] as i64 + }; + + Ok((log_probs, output_len)) + } +} + +impl SpeechModel for ZipformerCtcModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} From 93d6c5c6e7c5bafdff85282f494ad7551505d065 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:40:39 +0800 Subject: [PATCH 05/12] feat: add Zipformer Transducer ONNX engine Three-session RNN-T architecture (encoder, decoder, joiner) with greedy search decoding. Auto-detects I/O names and model file naming conventions. Rejects streaming models at load time. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/onnx/mod.rs | 1 + src/onnx/zipformer_transducer/mod.rs | 528 +++++++++++++++++++++++++++ 2 files changed, 529 insertions(+) create mode 100644 src/onnx/zipformer_transducer/mod.rs diff --git a/src/onnx/mod.rs b/src/onnx/mod.rs index d661273..9eb7b82 100644 --- a/src/onnx/mod.rs +++ b/src/onnx/mod.rs @@ -26,3 +26,4 @@ pub mod paraformer; pub mod parakeet; pub mod sense_voice; pub mod zipformer_ctc; +pub mod zipformer_transducer; diff --git a/src/onnx/zipformer_transducer/mod.rs b/src/onnx/zipformer_transducer/mod.rs new file mode 100644 index 0000000..3591e49 --- /dev/null +++ b/src/onnx/zipformer_transducer/mod.rs @@ -0,0 +1,528 @@ +//! Zipformer Transducer (RNN-T) ONNX speech recognition engine. +//! +//! Supports sherpa-onnx Zipformer Transducer models with 3 ONNX sessions +//! (encoder, decoder, joiner) and RNN-T greedy search decoding. Streaming +//! models (with `cached_*` inputs) are rejected at load time. + +use ndarray::{Array2, Array3, ArrayView1}; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use std::path::{Path, PathBuf}; + +use super::session; +use super::Quantization; +use crate::decode::BbpeSymbolTable; +use crate::features::{compute_kaldi_fbank, KaldiFbankConfig}; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Zipformer Transducer", + engine_id: "zipformer_transducer", + sample_rate: 16000, + languages: &["zh", "en", "vi", "ru", "ko"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Zipformer Transducer. +#[derive(Debug, Clone, Default)] +pub struct ZipformerTransducerParams { + /// Language hint (currently unused; the model handles languages automatically). + pub language: Option, +} + +// ---- Model ---- + +pub struct ZipformerTransducerModel { + encoder_session: Session, + decoder_session: Session, + joiner_session: Session, + symbol_table: BbpeSymbolTable, + blank_id: i32, + context_size: usize, // always 2 + // Encoder I/O names + enc_x_name: String, + enc_x_lens_name: String, + enc_out_name: String, + enc_out_lens_name: String, + // Decoder I/O names + dec_y_name: String, + dec_out_name: String, + // Joiner I/O names + join_enc_name: String, + join_dec_name: String, + join_logit_name: String, +} + +impl ZipformerTransducerModel { + /// Load a Zipformer Transducer model from `model_dir`. + /// + /// Expects three ONNX files (encoder, decoder, joiner) and a `tokens.txt` + /// file in the model directory. + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let suffix = match quantization { + Quantization::FP32 => "fp32", + Quantization::FP16 => "fp16", + Quantization::Int8 => "int8", + }; + + let encoder_path = Self::find_model_file(model_dir, "encoder", suffix)?; + let decoder_path = Self::find_model_file(model_dir, "decoder", suffix)?; + let joiner_path = Self::find_model_file(model_dir, "joiner", suffix)?; + + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!( + "Loading Zipformer Transducer model from {:?} (encoder={:?}, decoder={:?}, joiner={:?})...", + model_dir, + encoder_path.file_name().unwrap_or_default(), + decoder_path.file_name().unwrap_or_default(), + joiner_path.file_name().unwrap_or_default(), + ); + + let encoder_session = session::create_session(&encoder_path)?; + let decoder_session = session::create_session(&decoder_path)?; + let joiner_session = session::create_session(&joiner_path)?; + + // Reject streaming models — they have cached_* inputs + let enc_input_names: Vec = encoder_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + if enc_input_names.iter().any(|n| n.starts_with("cached_")) { + return Err(TranscribeError::Config(format!( + "Streaming Zipformer Transducer models are not supported (found cached_* inputs in {:?}). \ + Use a non-streaming (offline) model.", + encoder_path + ))); + } + + let enc_output_names: Vec = encoder_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let dec_input_names: Vec = decoder_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + let dec_output_names: Vec = decoder_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let join_input_names: Vec = joiner_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + let join_output_names: Vec = joiner_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + log::debug!("Encoder inputs: {:?}, outputs: {:?}", enc_input_names, enc_output_names); + log::debug!("Decoder inputs: {:?}, outputs: {:?}", dec_input_names, dec_output_names); + log::debug!("Joiner inputs: {:?}, outputs: {:?}", join_input_names, join_output_names); + + // Detect encoder I/O names + let enc_x_name = Self::find_name(&enc_input_names, &["x", "features", "input"]) + .unwrap_or_else(|| enc_input_names.first().cloned().unwrap_or_else(|| "x".to_string())); + + let enc_x_lens_name = + Self::find_name(&enc_input_names, &["x_lens", "x_length", "input_lengths"]) + .unwrap_or_else(|| { + enc_input_names.get(1).cloned().unwrap_or_else(|| "x_lens".to_string()) + }); + + let enc_out_name = + Self::find_name(&enc_output_names, &["encoder_out", "output", "encoder_output"]) + .unwrap_or_else(|| { + enc_output_names + .first() + .cloned() + .unwrap_or_else(|| "encoder_out".to_string()) + }); + + let enc_out_lens_name = Self::find_name( + &enc_output_names, + &["encoder_out_lens", "encoder_out_length", "output_lengths"], + ) + .unwrap_or_else(|| { + enc_output_names + .get(1) + .cloned() + .unwrap_or_else(|| "encoder_out_lens".to_string()) + }); + + // Detect decoder I/O names + let dec_y_name = Self::find_name(&dec_input_names, &["y", "input", "decoder_input"]) + .unwrap_or_else(|| dec_input_names.first().cloned().unwrap_or_else(|| "y".to_string())); + + let dec_out_name = + Self::find_name(&dec_output_names, &["decoder_out", "output", "decoder_output"]) + .unwrap_or_else(|| { + dec_output_names + .first() + .cloned() + .unwrap_or_else(|| "decoder_out".to_string()) + }); + + // Detect joiner I/O names + let join_enc_name = Self::find_name( + &join_input_names, + &["encoder_out", "enc_out", "encoder_input"], + ) + .unwrap_or_else(|| { + join_input_names + .first() + .cloned() + .unwrap_or_else(|| "encoder_out".to_string()) + }); + + let join_dec_name = Self::find_name( + &join_input_names, + &["decoder_out", "dec_out", "decoder_input"], + ) + .unwrap_or_else(|| { + join_input_names + .get(1) + .cloned() + .unwrap_or_else(|| "decoder_out".to_string()) + }); + + let join_logit_name = + Self::find_name(&join_output_names, &["logit", "output", "joiner_output"]) + .unwrap_or_else(|| { + join_output_names + .first() + .cloned() + .unwrap_or_else(|| "logit".to_string()) + }); + + log::debug!( + "I/O mapping: enc({}, {} -> {}, {}), dec({} -> {}), join({}, {} -> {})", + enc_x_name, + enc_x_lens_name, + enc_out_name, + enc_out_lens_name, + dec_y_name, + dec_out_name, + join_enc_name, + join_dec_name, + join_logit_name, + ); + + // Load BBPE symbol table (auto-detects BBPE vs BPE encoding) + let symbol_table = BbpeSymbolTable::load(&tokens_path)?; + + Ok(Self { + encoder_session, + decoder_session, + joiner_session, + symbol_table, + blank_id: 0, + context_size: 2, + enc_x_name, + enc_x_lens_name, + enc_out_name, + enc_out_lens_name, + dec_y_name, + dec_out_name, + join_enc_name, + join_dec_name, + join_logit_name, + }) + } + + /// Find an ONNX model file by component name, trying various naming conventions. + /// + /// Tries in order: + /// 1. `{component}.{suffix}.onnx` + /// 2. `{component}.onnx` + /// 3. Any file starting with `{component}` ending with `.{suffix}.onnx` + /// 4. Any file starting with `{component}` ending with `.onnx` + fn find_model_file( + model_dir: &Path, + component: &str, + suffix: &str, + ) -> Result { + // 1. Try exact: {component}.{suffix}.onnx + let exact_suffixed = model_dir.join(format!("{}.{}.onnx", component, suffix)); + if exact_suffixed.exists() { + return Ok(exact_suffixed); + } + + // 2. Try exact: {component}.onnx + let exact_plain = model_dir.join(format!("{}.onnx", component)); + if exact_plain.exists() { + return Ok(exact_plain); + } + + // 3/4. Scan directory for files matching the component prefix + if let Ok(read_dir) = std::fs::read_dir(model_dir) { + let mut suffixed_candidates: Vec = Vec::new(); + let mut plain_candidates: Vec = Vec::new(); + + for entry in read_dir.flatten() { + let path = entry.path(); + let file_name = match path.file_name().and_then(|n| n.to_str()) { + Some(n) => n.to_string(), + None => continue, + }; + + if !file_name.starts_with(component) { + continue; + } + + if file_name.ends_with(&format!(".{}.onnx", suffix)) { + suffixed_candidates.push(path); + } else if file_name.ends_with(".onnx") { + plain_candidates.push(path); + } + } + + // Sort for determinism + suffixed_candidates.sort(); + plain_candidates.sort(); + + // 3. Prefer suffixed match + if let Some(p) = suffixed_candidates.into_iter().next() { + log::info!( + "Found {} model by directory scan: {:?}", + component, + p.file_name().unwrap_or_default() + ); + return Ok(p); + } + + // 4. Fall back to any .onnx match + if let Some(p) = plain_candidates.into_iter().next() { + log::info!( + "Found {} model by directory scan: {:?}", + component, + p.file_name().unwrap_or_default() + ); + return Ok(p); + } + } + + Err(TranscribeError::ModelNotFound(exact_suffixed)) + } + + /// Find a name from a list of candidates, returning the first match. + fn find_name(names: &[String], candidates: &[&str]) -> Option { + for &candidate in candidates { + if let Some(n) = names.iter().find(|n| n.as_str() == candidate) { + return Some(n.clone()); + } + } + None + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ZipformerTransducerParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Kaldi FBank features [frames, 80] + let features = compute_kaldi_fbank(samples, &KaldiFbankConfig::default()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!( + "Kaldi fbank: [{}, {}]", + features.nrows(), + features.ncols() + ); + + // 2. RNN-T greedy search + let token_ids = self.greedy_search(&features)?; + + // 3. Decode tokens to text + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run encoder: features [1,T,80] + lens [1] -> encoder_out [1,T',D] + encoder_out_lens [1] + fn run_encoder( + &mut self, + features: &Array2, + ) -> Result<(Array3, i64), TranscribeError> { + let num_frames = features.nrows(); + let feat_3d = features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; + let lens = ndarray::arr1(&[num_frames as i64]).into_dyn(); + + let feat_dyn = feat_3d.into_dyn(); + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lens = TensorRef::from_array_view(lens.view())?; + + let inputs = inputs![ + self.enc_x_name.as_str() => t_feat, + self.enc_x_lens_name.as_str() => t_lens, + ]; + let outputs = self.encoder_session.run(inputs)?; + + let encoder_out = outputs + .get(self.enc_out_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "encoder output '{}' not found", + self.enc_out_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + let encoder_out_lens = outputs + .get(self.enc_out_lens_name.as_str()) + .and_then(|v| v.try_extract_array::().ok()) + .and_then(|arr| arr.as_slice().and_then(|s| s.first().copied())) + .unwrap_or(encoder_out.shape()[1] as i64); + + Ok((encoder_out, encoder_out_lens)) + } + + /// Run decoder: y [1, context_size] (i64) -> decoder_out [1, D] + fn run_decoder(&mut self, context: &[i64]) -> Result, TranscribeError> { + let y = + ndarray::Array2::from_shape_vec((1, self.context_size), context.to_vec())?; + let y_dyn = y.into_dyn(); + let t_y = TensorRef::from_array_view(y_dyn.view())?; + + let inputs = inputs![ + self.dec_y_name.as_str() => t_y, + ]; + let outputs = self.decoder_session.run(inputs)?; + + let decoder_out = outputs + .get(self.dec_out_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "decoder output '{}' not found", + self.dec_out_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + Ok(decoder_out) + } + + /// Run joiner: encoder_out_frame [1,D] + decoder_out [1,D] -> logit [1, vocab_size] + fn run_joiner( + &mut self, + encoder_out_frame: &ArrayView1, + decoder_out: &Array2, + ) -> Result, TranscribeError> { + let enc = encoder_out_frame + .to_owned() + .into_shape_with_order((1, encoder_out_frame.len()))? + .into_dyn(); + let dec_dyn = decoder_out.clone().into_dyn(); + + let t_enc = TensorRef::from_array_view(enc.view())?; + let t_dec = TensorRef::from_array_view(dec_dyn.view())?; + + let inputs = inputs![ + self.join_enc_name.as_str() => t_enc, + self.join_dec_name.as_str() => t_dec, + ]; + let outputs = self.joiner_session.run(inputs)?; + + let logit = outputs + .get(self.join_logit_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "joiner output '{}' not found", + self.join_logit_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + Ok(logit) + } + + /// Greedy search decoding for RNN-T (transducer). + fn greedy_search(&mut self, features: &Array2) -> Result, TranscribeError> { + let (encoder_out, encoder_out_lens) = self.run_encoder(features)?; + let t_max = (encoder_out_lens as usize).min(encoder_out.shape()[1]); + + let mut context = vec![self.blank_id as i64; self.context_size]; + let mut decoder_out = self.run_decoder(&context)?; + + let mut tokens = Vec::new(); + + for t in 0..t_max { + let enc_frame = encoder_out.slice(ndarray::s![0, t, ..]); + let logit = self.run_joiner(&enc_frame, &decoder_out)?; + + let logit_row = logit.row(0); + let mut max_id = 0usize; + let mut max_val = f32::NEG_INFINITY; + for (i, &v) in logit_row.iter().enumerate() { + if v > max_val { + max_val = v; + max_id = i; + } + } + + if max_id as i32 != self.blank_id { + tokens.push(max_id as i32); + context.rotate_left(1); + *context.last_mut().unwrap() = max_id as i64; + decoder_out = self.run_decoder(&context)?; + } + } + + Ok(tokens) + } +} + +impl SpeechModel for ZipformerTransducerModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} From 630152908d1d4f49123fc15d9154c1c4323e192f Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:44:33 +0800 Subject: [PATCH 06/12] feat: add neural punctuation restoration model Implements PunctModel backed by CT-Transformer ONNX, with sliding-window inference (20-token chunks, 2-token overlap) and smart CJK/ASCII punctuation selection. Adds independent `punct` feature gate and updates `all`. Co-Authored-By: Claude Sonnet 4.6 --- Cargo.toml | 5 +- src/error.rs | 4 +- src/lib.rs | 3 + src/punct.rs | 440 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 449 insertions(+), 3 deletions(-) create mode 100644 src/punct.rs diff --git a/Cargo.toml b/Cargo.toml index d1395be..bda5202 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,9 @@ openai = ["dep:async-openai", "dep:tokio", "dep:async-trait"] # Silero neural VAD (requires silero_vad_v4.onnx model file) vad-silero = ["dep:ort", "dep:ndarray"] +# Neural punctuation restoration (CT-Transformer) +punct = ["dep:ort", "dep:ndarray"] + # --- ORT Accelerators --- # Note: ort-cuda pulls in the CUDA execution provider, which adds ~800 MB+ # to the ORT binary and requires a CUDA toolkit / cuDNN installation at runtime. @@ -44,7 +47,7 @@ ort-webgpu = ["onnx", "ort/webgpu"] ort-accel = ["ort-cuda", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu"] # Convenience -all = ["onnx", "whisper-cpp", "whisperfile", "openai"] +all = ["onnx", "whisper-cpp", "whisperfile", "openai", "punct"] [dependencies] # Always required diff --git a/src/error.rs b/src/error.rs index d2dc466..a06a24c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -36,14 +36,14 @@ impl From for TranscribeError { } } -#[cfg(any(feature = "onnx", feature = "vad-silero"))] +#[cfg(any(feature = "onnx", feature = "vad-silero", feature = "punct"))] impl From for TranscribeError { fn from(e: ort::Error) -> Self { TranscribeError::Inference(e.to_string()) } } -#[cfg(any(feature = "audio-features", feature = "vad-silero"))] +#[cfg(any(feature = "audio-features", feature = "vad-silero", feature = "punct"))] impl From for TranscribeError { fn from(e: ndarray::ShapeError) -> Self { TranscribeError::Inference(e.to_string()) diff --git a/src/lib.rs b/src/lib.rs index deb229e..090d56f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,6 +116,9 @@ pub mod remote; #[cfg(feature = "openai")] pub use remote::RemoteTranscriptionEngine; +#[cfg(feature = "punct")] +pub mod punct; + use std::path::Path; /// Describes the capabilities of a speech model. diff --git a/src/punct.rs b/src/punct.rs new file mode 100644 index 0000000..e7e4d27 --- /dev/null +++ b/src/punct.rs @@ -0,0 +1,440 @@ +//! Neural network-based Chinese/English punctuation restoration. +//! +//! Uses a CT-Transformer ONNX model to insert punctuation into raw ASR text. +//! The model operates at the character/token level and supports both Chinese +//! (full-width punctuation) and English (ASCII punctuation) contexts. +//! +//! # Feature gate +//! +//! This module requires the `punct` feature: +//! +//! ```toml +//! [dependencies] +//! transcribe-rs = { version = "0.3", features = ["punct"] } +//! ``` +//! +//! # Model files +//! +//! You need the CT-Transformer model directory containing: +//! - `model.int8.onnx` (or `model.onnx` as fallback) +//! - `tokens.json` (vocabulary file as a JSON array of strings) +//! +//! # Usage +//! +//! ```ignore +//! use std::path::Path; +//! use transcribe_rs::punct::PunctModel; +//! +//! let mut model = PunctModel::new(Path::new("models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8/"))?; +//! let result = model.add_punctuation("今天天气很好我们去公园玩吧"); +//! println!("{}", result); +//! // => "今天天气很好,我们去公园玩吧。" +//! ``` + +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; + +use ndarray::{Array1, Array2}; +use ort::inputs; +use ort::session::builder::GraphOptimizationLevel; +use ort::session::Session; +use ort::value::TensorRef; + +use crate::TranscribeError; + +// ── Punctuation type enum ──────────────────────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PunctType { + Underscore = 0, + Comma = 2, + Dot = 3, + Quest = 4, + Pause = 5, +} + +impl PunctType { + fn from_id(id: usize) -> Option { + match id { + 0 => Some(PunctType::Underscore), + 2 => Some(PunctType::Comma), + 3 => Some(PunctType::Dot), + 4 => Some(PunctType::Quest), + 5 => Some(PunctType::Pause), + _ => None, + } + } + + fn to_char(self) -> Option { + match self { + PunctType::Underscore => None, + PunctType::Comma => Some(','), + PunctType::Dot => Some('。'), + PunctType::Quest => Some('?'), + PunctType::Pause => Some('、'), + } + } + + fn to_ascii_char(self) -> Option { + match self { + PunctType::Underscore => None, + PunctType::Comma => Some(','), + PunctType::Dot => Some('.'), + PunctType::Quest => Some('?'), + PunctType::Pause => Some(','), + } + } +} + +// ── Token classification ───────────────────────────────────────────────────── + +#[derive(Debug, Clone)] +enum TokenInfo { + Word(String), + Char(char), + Space, + Punct(char), +} + +// ── PunctModel ──────────────────────────────────────────────────────────────── + +/// CT-Transformer ONNX-based punctuation restoration model. +/// +/// Tokenizes input text, runs inference in 20-token windows with 2-token +/// overlap, and reconstructs text with intelligently placed punctuation marks. +pub struct PunctModel { + session: Session, + token2id: HashMap, + unk_id: i32, + input_name: String, + length_name: String, +} + +impl PunctModel { + /// Load a PunctModel from a model directory. + /// + /// The directory must contain `model.int8.onnx` (or `model.onnx`) and + /// `tokens.json`. + pub fn new(model_dir: &Path) -> Result { + let model_path = model_dir.join("model.int8.onnx"); + let model_path = if !model_path.exists() { + model_dir.join("model.onnx") + } else { + model_path + }; + let tokens_path = model_dir.join("tokens.json"); + + if !model_path.exists() { + return Err(TranscribeError::ModelNotFound(model_path)); + } + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading punctuation model from {:?}...", model_path); + + let session = Session::builder() + .map_err(|e| TranscribeError::Config(format!("ort session builder: {e}")))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| TranscribeError::Config(format!("ort optimization level: {e}")))? + .with_parallel_execution(true) + .map_err(|e| TranscribeError::Config(format!("ort parallel execution: {e}")))? + .commit_from_file(&model_path) + .map_err(|e| TranscribeError::Inference(format!("failed to load punct model: {e}")))?; + + let (token2id, unk_id) = Self::load_tokens(&tokens_path)?; + + // session.inputs() returns &[Outlet] — index directly + let input_name = session.inputs()[0].name().to_string(); + let length_name = session.inputs()[1].name().to_string(); + + log::info!( + "Punct model input names: '{}' and '{}'", + input_name, + length_name + ); + + Ok(Self { + session, + token2id, + unk_id, + input_name, + length_name, + }) + } + + fn load_tokens(path: &Path) -> Result<(HashMap, i32), TranscribeError> { + let file = File::open(path)?; + let tokens: Vec = serde_json::from_reader(file)?; + let mut token2id = HashMap::new(); + for (id, token) in tokens.iter().enumerate() { + token2id.insert(token.clone(), id as i32); + } + let unk_id = *token2id.get("").unwrap_or(&0); + log::info!("Loaded {} tokens, unk_id={}", tokens.len(), unk_id); + Ok((token2id, unk_id)) + } + + /// Tokenize input text into (token_ids, token_infos). + /// + /// - CJK characters are tokenized one character at a time. + /// - ASCII words/digit runs are kept as a single token. + /// - Whitespace is recorded as `TokenInfo::Space` (not submitted to model). + /// - Existing punctuation is recorded as `TokenInfo::Punct` and skipped. + fn tokenize(&self, text: &str) -> (Vec, Vec) { + let mut ids: Vec = Vec::new(); + let mut infos: Vec = Vec::new(); + + let mut chars = text.chars().peekable(); + while let Some(c) = chars.next() { + if c.is_whitespace() { + infos.push(TokenInfo::Space); + continue; + } + + // Existing punctuation — preserve as-is, don't tokenize + if is_existing_punct(c) { + infos.push(TokenInfo::Punct(c)); + continue; + } + + if is_cjk(c) { + // Single CJK character + let token = c.to_string(); + let id = *self.token2id.get(&token).unwrap_or(&self.unk_id); + ids.push(id); + infos.push(TokenInfo::Char(c)); + } else { + // Collect a run of non-CJK, non-space, non-punct chars as a word + let mut word = String::new(); + word.push(c); + while let Some(&nc) = chars.peek() { + if nc.is_whitespace() || is_cjk(nc) || is_existing_punct(nc) { + break; + } + word.push(nc); + chars.next(); + } + let lower = word.to_lowercase(); + let id = self + .token2id + .get(&lower) + .copied() + .unwrap_or_else(|| *self.token2id.get(&word).unwrap_or(&self.unk_id)); + ids.push(id); + infos.push(TokenInfo::Word(word)); + } + } + + (ids, infos) + } + + /// Run punctuation inference on a batch of token IDs. + /// + /// Returns a Vec of punctuation class IDs, one per input token. + fn run_inference(&mut self, token_ids: &[i32]) -> Result, TranscribeError> { + let seq_len = token_ids.len(); + let input_array = Array2::from_shape_vec( + (1, seq_len), + token_ids.iter().map(|&x| x as i64).collect(), + ) + .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; + let length_array = Array1::from_vec(vec![seq_len as i64]); + + let input_tensor = TensorRef::from_array_view(input_array.view()) + .map_err(|e| TranscribeError::Inference(format!("input tensor: {e}")))?; + let length_tensor = TensorRef::from_array_view(length_array.view()) + .map_err(|e| TranscribeError::Inference(format!("length tensor: {e}")))?; + + let outputs = self + .session + .run(inputs![ + self.input_name.as_str() => input_tensor, + self.length_name.as_str() => length_tensor + ]) + .map_err(|e| TranscribeError::Inference(format!("inference: {e}")))?; + + let output = outputs[0] + .try_extract_array::() + .map_err(|e| TranscribeError::Inference(format!("extract output: {e}")))?; + let punct_ids: Vec = output.iter().map(|&x| x as usize).collect(); + Ok(punct_ids) + } + + /// Add punctuation to raw ASR text. + /// + /// On inference error, logs a warning and returns the original text unchanged. + pub fn add_punctuation(&mut self, text: &str) -> String { + if text.is_empty() { + return text.to_string(); + } + + let (token_ids, token_infos) = self.tokenize(text); + + if token_ids.is_empty() { + return text.to_string(); + } + + // Process in windows of 20 tokens with 2-token overlap + const WINDOW: usize = 20; + const OVERLAP: usize = 2; + + let mut punctuations: Vec = vec![0usize; token_ids.len()]; + + let mut start = 0; + loop { + let end = (start + WINDOW).min(token_ids.len()); + let chunk = &token_ids[start..end]; + + match self.run_inference(chunk) { + Ok(preds) => { + // Only take the non-overlapping part (except for the last chunk) + let take = if end < token_ids.len() { + preds.len().saturating_sub(OVERLAP) + } else { + preds.len() + }; + for (i, &p) in preds[..take].iter().enumerate() { + punctuations[start + i] = p; + } + } + Err(e) => { + log::warn!("Punctuation inference failed: {}", e); + return text.to_string(); + } + } + + if end >= token_ids.len() { + break; + } + start += WINDOW - OVERLAP; + } + + self.reconstruct_with_punctuation(&token_infos, &punctuations) + } + + /// Reconstruct the output string by interleaving tokens with predicted punctuation. + fn reconstruct_with_punctuation( + &self, + token_infos: &[TokenInfo], + punctuations: &[usize], + ) -> String { + let mut result = String::new(); + let mut punct_iter = punctuations.iter(); + + for info in token_infos { + match info { + TokenInfo::Space => { + // Spaces are absorbed; punctuation takes their place + } + TokenInfo::Punct(c) => { + result.push(*c); + } + TokenInfo::Char(c) => { + result.push(*c); + if let Some(&pid) = punct_iter.next() { + if let Some(pt) = PunctType::from_id(pid) { + // CJK character → use full-width punctuation + if let Some(pc) = pt.to_char() { + result.push(pc); + } + } + } + } + TokenInfo::Word(w) => { + result.push_str(w); + if let Some(&pid) = punct_iter.next() { + if let Some(pt) = PunctType::from_id(pid) { + let pc = choose_punct_char(pt, w, &result); + if let Some(pc) = pc { + result.push(pc); + } + } + } + } + } + } + + result + } +} + +// ── Helper functions ───────────────────────────────────────────────────────── + +/// Choose full-width or ASCII punctuation based on surrounding context. +fn choose_punct_char(pt: PunctType, current_word: &str, result_so_far: &str) -> Option { + // If the current word is an English/ASCII word, use ASCII punctuation. + // If the preceding content ends in a CJK character, use full-width. + let last_meaningful = result_so_far + .chars() + .rev() + .find(|c| !c.is_whitespace()); + + let use_ascii = is_english_token(current_word) + || last_meaningful + .map(|c| !is_cjk(c) && c.is_ascii_alphanumeric()) + .unwrap_or(false); + + if use_ascii { + pt.to_ascii_char() + } else { + pt.to_char() + } +} + +/// Return true if the character is a punctuation mark that should be preserved. +fn is_existing_punct(c: char) -> bool { + c.is_ascii_punctuation() + || matches!( + c, + ',' | '。' + | '?' + | '!' + | '、' + | ';' + | ':' + | '…' + | '\u{201C}' // " + | '\u{201D}' // " + | '\u{2018}' // ' + | '\u{2019}' // ' + | '—' + | '【' + | '】' + | '《' + | '》' + | '(' + | ')' + ) +} + +/// Return true if the token looks like an English/ASCII word or number. +fn is_english_token(token: &str) -> bool { + !token.is_empty() + && token + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '\'') +} + +/// Return true if the token consists entirely of CJK characters. +#[allow(dead_code)] +fn is_cjk_token(token: &str) -> bool { + !token.is_empty() && token.chars().all(is_cjk) +} + +/// Return true if the character is in a CJK Unicode block. +fn is_cjk(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' // CJK Unified Ideographs + | '\u{3400}'..='\u{4DBF}' // CJK Extension A + | '\u{20000}'..='\u{2A6DF}' // CJK Extension B + | '\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs + | '\u{2F800}'..='\u{2FA1F}' // CJK Compatibility Supplement + | '\u{3000}'..='\u{303F}' // CJK Symbols and Punctuation + | '\u{31F0}'..='\u{31FF}' // Katakana Phonetic Extensions + | '\u{3200}'..='\u{32FF}' // Enclosed CJK + | '\u{3300}'..='\u{33FF}' // CJK Compatibility + | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables + ) +} From da593a3a83275033b1573a84b84b3270650d0441 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:46:44 +0800 Subject: [PATCH 07/12] feat: add examples and tests for Paraformer and Zipformer engines Add example binaries and integration tests for ParaformerModel, ZipformerCtcModel, and ZipformerTransducerModel following the existing gigaam/sense_voice patterns. Tests skip gracefully when model files are absent; examples accept positional args and --int8 flag. Co-Authored-By: Claude Sonnet 4.6 --- Cargo.toml | 24 +++++++++ examples/paraformer.rs | 87 ++++++++++++++++++++++++++++++++ examples/zipformer_ctc.rs | 87 ++++++++++++++++++++++++++++++++ examples/zipformer_transducer.rs | 87 ++++++++++++++++++++++++++++++++ tests/paraformer.rs | 31 ++++++++++++ tests/zipformer_ctc.rs | 31 ++++++++++++ tests/zipformer_transducer.rs | 32 ++++++++++++ 7 files changed, 379 insertions(+) create mode 100644 examples/paraformer.rs create mode 100644 examples/zipformer_ctc.rs create mode 100644 examples/zipformer_transducer.rs create mode 100644 tests/paraformer.rs create mode 100644 tests/zipformer_ctc.rs create mode 100644 tests/zipformer_transducer.rs diff --git a/Cargo.toml b/Cargo.toml index bda5202..3cd77a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,6 +118,18 @@ required-features = ["whisperfile"] name = "openai" required-features = ["openai"] +[[example]] +name = "paraformer" +required-features = ["onnx"] + +[[example]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[example]] +name = "zipformer_transducer" +required-features = ["onnx"] + [dev-dependencies] once_cell = "1.21.3" @@ -161,3 +173,15 @@ required-features = ["onnx", "vad-silero"] [[test]] name = "vad_silero" required-features = ["vad-silero"] + +[[test]] +name = "paraformer" +required-features = ["onnx"] + +[[test]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[test]] +name = "zipformer_transducer" +required-features = ["onnx"] diff --git a/examples/paraformer.rs b/examples/paraformer.rs new file mode 100644 index 0000000..550bf2a --- /dev/null +++ b/examples/paraformer.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::paraformer::ParaformerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-paraformer-zh-2025-10-07"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Paraformer engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ParaformerModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/examples/zipformer_ctc.rs b/examples/zipformer_ctc.rs new file mode 100644 index 0000000..bcf2eec --- /dev/null +++ b/examples/zipformer_ctc.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::zipformer_ctc::ZipformerCtcModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Zipformer CTC engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ZipformerCtcModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/examples/zipformer_transducer.rs b/examples/zipformer_transducer.rs new file mode 100644 index 0000000..8067ba6 --- /dev/null +++ b/examples/zipformer_transducer.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::zipformer_transducer::ZipformerTransducerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-zipformer-zh-en-2023-11-22"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Zipformer Transducer engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ZipformerTransducerModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/tests/paraformer.rs b/tests/paraformer.rs new file mode 100644 index 0000000..7741d37 --- /dev/null +++ b/tests/paraformer.rs @@ -0,0 +1,31 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::paraformer::ParaformerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_paraformer_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-paraformer-zh-2025-10-07"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = + ParaformerModel::load(&model_dir, &Quantization::Int8).expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +} diff --git a/tests/zipformer_ctc.rs b/tests/zipformer_ctc.rs new file mode 100644 index 0000000..99a18d9 --- /dev/null +++ b/tests/zipformer_ctc.rs @@ -0,0 +1,31 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::zipformer_ctc::ZipformerCtcModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_zipformer_ctc_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = + ZipformerCtcModel::load(&model_dir, &Quantization::Int8).expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +} diff --git a/tests/zipformer_transducer.rs b/tests/zipformer_transducer.rs new file mode 100644 index 0000000..955c663 --- /dev/null +++ b/tests/zipformer_transducer.rs @@ -0,0 +1,32 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::zipformer_transducer::ZipformerTransducerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_zipformer_transducer_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-zipformer-zh-en-2023-11-22"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = + ZipformerTransducerModel::load(&model_dir, &Quantization::Int8) + .expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +} From e796f50b41b9c959b40f5393d8e61a4242f5a965 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Sun, 29 Mar 2026 18:49:08 +0800 Subject: [PATCH 08/12] chore: fix clippy warnings and cargo fmt - Fix loop variable indexing in kaldi_fbank.rs - Apply cargo fmt to all new files Co-Authored-By: Claude Opus 4.6 (1M context) --- src/decode/bbpe.rs | 40 +++++++------ src/decode/mod.rs | 2 +- src/features/kaldi_fbank.rs | 19 +++--- src/onnx/paraformer/mod.rs | 41 +++++++------ src/onnx/zipformer_ctc/mod.rs | 36 +++++------ src/onnx/zipformer_transducer/mod.rs | 89 ++++++++++++++++++---------- src/punct.rs | 13 ++-- tests/zipformer_transducer.rs | 5 +- 8 files changed, 132 insertions(+), 113 deletions(-) diff --git a/src/decode/bbpe.rs b/src/decode/bbpe.rs index 8d7af7a..cf3844d 100644 --- a/src/decode/bbpe.rs +++ b/src/decode/bbpe.rs @@ -46,7 +46,10 @@ impl BbpeSymbolTable { } /// Load with an explicitly specified encoding. - pub fn load_with_encoding(path: &Path, encoding: TokenEncoding) -> Result { + pub fn load_with_encoding( + path: &Path, + encoding: TokenEncoding, + ) -> Result { let contents = fs::read_to_string(path)?; let mut id_to_sym = HashMap::new(); for line in contents.lines() { @@ -62,7 +65,10 @@ impl BbpeSymbolTable { id_to_sym.insert(id, parts[1].to_string()); } } - Ok(Self { id_to_sym, encoding }) + Ok(Self { + id_to_sym, + encoding, + }) } /// Look up a symbol by token ID. @@ -141,23 +147,19 @@ fn normalize_text(text: &str) -> String { /// BBPE codepoint table: maps byte value (index) to Unicode codepoint. const BBPE_CODEPOINTS: [u32; 256] = [ - 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, - 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, - 286, 287, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, - 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, - 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, - 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 288, 289, 290, 291, 292, - 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, 309, - 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, - 327, 328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, - 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, - 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, - 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, - 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, - 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, - 419, 420, 421, 422, + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, + 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, + 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, 327, 328, 330, + 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, + 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, + 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, + 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, + 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, ]; /// Convert a BBPE-encoded Unicode character back to its original byte value. diff --git a/src/decode/mod.rs b/src/decode/mod.rs index 9ff8655..24d5a98 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -3,7 +3,7 @@ mod ctc; mod sentencepiece; pub mod tokens; +pub use bbpe::BbpeSymbolTable; pub use ctc::{ctc_greedy_decode, CtcDecoderResult}; pub use sentencepiece::sentencepiece_to_text; -pub use bbpe::BbpeSymbolTable; pub use tokens::{load_vocab, SymbolTable}; diff --git a/src/features/kaldi_fbank.rs b/src/features/kaldi_fbank.rs index 9a48169..0274239 100644 --- a/src/features/kaldi_fbank.rs +++ b/src/features/kaldi_fbank.rs @@ -74,8 +74,7 @@ pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 let window: Vec = (0..window_size) .map(|i| { let hamming = 0.54 - - 0.46 - * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); + - 0.46 * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); hamming.powf(0.85) }) .collect(); @@ -94,10 +93,10 @@ pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 let start = center as isize - (window_size as isize / 2); let mut frame = vec![0.0f32; window_size]; - for i in 0..window_size { + for (i, sample) in frame.iter_mut().enumerate() { let idx = start + i as isize; if idx >= 0 && (idx as usize) < samples.len() { - frame[i] = samples[idx as usize]; + *sample = samples[idx as usize]; } } @@ -179,17 +178,13 @@ fn mel_filterbank(config: &KaldiFbankConfig) -> Vec> { let center = fft_bins[i + 1]; let right = fft_bins[i + 2]; if center > left { - for j in left..center { - if j < half_fft { - filter[j] = (j - left) as f32 / (center - left) as f32; - } + for (idx, val) in filter[left..center.min(half_fft)].iter_mut().enumerate() { + *val = idx as f32 / (center - left) as f32; } } if right > center { - for j in center..right { - if j < half_fft { - filter[j] = (right - j) as f32 / (right - center) as f32; - } + for (idx, val) in filter[center..right.min(half_fft)].iter_mut().enumerate() { + *val = (right - center - idx) as f32 / (right - center) as f32; } } } diff --git a/src/onnx/paraformer/mod.rs b/src/onnx/paraformer/mod.rs index bb396f3..688682e 100644 --- a/src/onnx/paraformer/mod.rs +++ b/src/onnx/paraformer/mod.rs @@ -123,9 +123,7 @@ impl ParaformerSymbolTable { let curr_is_ascii_word = is_ascii_word_piece(clean); let prev_is_ascii_word = prev_char.map(is_ascii_word_char).unwrap_or(false); let prev_is_cjk = prev_char.map(is_cjk).unwrap_or(false); - if curr_is_ascii_word - && (prev_is_ascii_word || prev_is_cjk) - && !text.ends_with(' ') + if curr_is_ascii_word && (prev_is_ascii_word || prev_is_cjk) && !text.ends_with(' ') { text.push(' '); } @@ -410,7 +408,11 @@ impl ParaformerModel { ); // Detect I/O names from session - let inputs: Vec = session.inputs().iter().map(|i| i.name().to_string()).collect(); + let inputs: Vec = session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); let outputs: Vec = session .outputs() .iter() @@ -421,7 +423,12 @@ impl ParaformerModel { .iter() .find(|n| n.contains("speech")) .cloned() - .unwrap_or_else(|| inputs.first().cloned().unwrap_or_else(|| "speech".to_string())); + .unwrap_or_else(|| { + inputs + .first() + .cloned() + .unwrap_or_else(|| "speech".to_string()) + }); let speech_lengths_input_name = inputs .iter() @@ -438,12 +445,14 @@ impl ParaformerModel { .iter() .find(|n| n.contains("logits")) .cloned() - .unwrap_or_else(|| outputs.first().cloned().unwrap_or_else(|| "logits".to_string())); + .unwrap_or_else(|| { + outputs + .first() + .cloned() + .unwrap_or_else(|| "logits".to_string()) + }); - let token_num_output_name = outputs - .iter() - .find(|n| n.contains("token_num")) - .cloned(); + let token_num_output_name = outputs.iter().find(|n| n.contains("token_num")).cloned(); log::debug!( "I/O names: speech={}, lengths={}, logits={}, token_num={:?}", @@ -554,16 +563,14 @@ impl ParaformerModel { } /// Run ONNX forward pass. Returns logits [1, T, vocab_size]. - fn forward( - &mut self, - features: &Array2, - ) -> Result, TranscribeError> { + fn forward(&mut self, features: &Array2) -> Result, TranscribeError> { let num_frames = features.nrows(); // Shape: [1, T, D] - let feat_3d = features - .to_owned() - .into_shape_with_order((1, num_frames, features.ncols()))?; + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; let speech_lengths = ndarray::arr1(&[num_frames as i32]); let feat_dyn = feat_3d.into_dyn(); diff --git a/src/onnx/zipformer_ctc/mod.rs b/src/onnx/zipformer_ctc/mod.rs index ac2b7d5..40058d4 100644 --- a/src/onnx/zipformer_ctc/mod.rs +++ b/src/onnx/zipformer_ctc/mod.rs @@ -159,7 +159,10 @@ impl ZipformerCtcModel { /// Priority: /// 1. `session::resolve_model_path(dir, "model", quantization)` — standard naming /// 2. Scan directory for `*.int8.onnx` (when Int8 requested) or `*.onnx` - fn find_model_file(model_dir: &Path, quantization: &Quantization) -> Result { + fn find_model_file( + model_dir: &Path, + quantization: &Quantization, + ) -> Result { // Try standard path first let standard_path = session::resolve_model_path(model_dir, "model", quantization); if standard_path.exists() { @@ -182,10 +185,7 @@ impl ZipformerCtcModel { for entry in read_dir.flatten() { let path = entry.path(); if path.extension().and_then(|e| e.to_str()) == Some("onnx") { - let name = path - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or(""); + let name = path.file_name().and_then(|n| n.to_str()).unwrap_or(""); if name.contains("int8") || name.contains("int4") { int8_candidates.push(path); } else { @@ -211,9 +211,7 @@ impl ZipformerCtcModel { } // Last resort: return the int8 candidate even if fp32 preferred - Err(TranscribeError::ModelNotFound( - model_dir.join("model.onnx"), - )) + Err(TranscribeError::ModelNotFound(model_dir.join("model.onnx"))) } /// Transcribe with model-specific parameters. @@ -236,11 +234,7 @@ impl ZipformerCtcModel { }); } - log::debug!( - "Kaldi fbank: [{}, {}]", - features.nrows(), - features.ncols() - ); + log::debug!("Kaldi fbank: [{}, {}]", features.nrows(), features.ncols()); // 2. Run ONNX forward pass → log_probs [1, T', vocab] let (log_probs, output_len) = self.forward(&features)?; @@ -269,13 +263,17 @@ impl ZipformerCtcModel { /// /// Returns `(log_probs [1, T, vocab], output_len)` where `output_len` is /// the valid frame count for batch element 0. - fn forward(&mut self, features: &Array2) -> Result<(ndarray::Array3, i64), TranscribeError> { + fn forward( + &mut self, + features: &Array2, + ) -> Result<(ndarray::Array3, i64), TranscribeError> { let num_frames = features.nrows() as i64; // Shape: [1, T, 80] - let feat_3d = features - .to_owned() - .into_shape_with_order((1, features.nrows(), features.ncols()))?; + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, features.nrows(), features.ncols()))?; let x_lens = ndarray::arr1(&[num_frames]); let feat_dyn = feat_3d.into_dyn(); @@ -293,9 +291,7 @@ impl ZipformerCtcModel { // Extract log_probs — always the first output, shape [1, T', vocab] let log_probs = outputs[0].try_extract_array::()?; - let log_probs = log_probs - .to_owned() - .into_dimensionality::()?; + let log_probs = log_probs.to_owned().into_dimensionality::()?; // Determine output length: use the length output if available, else T' let output_len = if let Some(len_idx) = self.log_probs_len_output_idx { diff --git a/src/onnx/zipformer_transducer/mod.rs b/src/onnx/zipformer_transducer/mod.rs index 3591e49..9eef497 100644 --- a/src/onnx/zipformer_transducer/mod.rs +++ b/src/onnx/zipformer_transducer/mod.rs @@ -135,29 +135,51 @@ impl ZipformerTransducerModel { .map(|o| o.name().to_string()) .collect(); - log::debug!("Encoder inputs: {:?}, outputs: {:?}", enc_input_names, enc_output_names); - log::debug!("Decoder inputs: {:?}, outputs: {:?}", dec_input_names, dec_output_names); - log::debug!("Joiner inputs: {:?}, outputs: {:?}", join_input_names, join_output_names); + log::debug!( + "Encoder inputs: {:?}, outputs: {:?}", + enc_input_names, + enc_output_names + ); + log::debug!( + "Decoder inputs: {:?}, outputs: {:?}", + dec_input_names, + dec_output_names + ); + log::debug!( + "Joiner inputs: {:?}, outputs: {:?}", + join_input_names, + join_output_names + ); // Detect encoder I/O names let enc_x_name = Self::find_name(&enc_input_names, &["x", "features", "input"]) - .unwrap_or_else(|| enc_input_names.first().cloned().unwrap_or_else(|| "x".to_string())); + .unwrap_or_else(|| { + enc_input_names + .first() + .cloned() + .unwrap_or_else(|| "x".to_string()) + }); let enc_x_lens_name = Self::find_name(&enc_input_names, &["x_lens", "x_length", "input_lengths"]) .unwrap_or_else(|| { - enc_input_names.get(1).cloned().unwrap_or_else(|| "x_lens".to_string()) - }); - - let enc_out_name = - Self::find_name(&enc_output_names, &["encoder_out", "output", "encoder_output"]) - .unwrap_or_else(|| { - enc_output_names - .first() + enc_input_names + .get(1) .cloned() - .unwrap_or_else(|| "encoder_out".to_string()) + .unwrap_or_else(|| "x_lens".to_string()) }); + let enc_out_name = Self::find_name( + &enc_output_names, + &["encoder_out", "output", "encoder_output"], + ) + .unwrap_or_else(|| { + enc_output_names + .first() + .cloned() + .unwrap_or_else(|| "encoder_out".to_string()) + }); + let enc_out_lens_name = Self::find_name( &enc_output_names, &["encoder_out_lens", "encoder_out_length", "output_lengths"], @@ -171,16 +193,23 @@ impl ZipformerTransducerModel { // Detect decoder I/O names let dec_y_name = Self::find_name(&dec_input_names, &["y", "input", "decoder_input"]) - .unwrap_or_else(|| dec_input_names.first().cloned().unwrap_or_else(|| "y".to_string())); + .unwrap_or_else(|| { + dec_input_names + .first() + .cloned() + .unwrap_or_else(|| "y".to_string()) + }); - let dec_out_name = - Self::find_name(&dec_output_names, &["decoder_out", "output", "decoder_output"]) - .unwrap_or_else(|| { - dec_output_names - .first() - .cloned() - .unwrap_or_else(|| "decoder_out".to_string()) - }); + let dec_out_name = Self::find_name( + &dec_output_names, + &["decoder_out", "output", "decoder_output"], + ) + .unwrap_or_else(|| { + dec_output_names + .first() + .cloned() + .unwrap_or_else(|| "decoder_out".to_string()) + }); // Detect joiner I/O names let join_enc_name = Self::find_name( @@ -354,11 +383,7 @@ impl ZipformerTransducerModel { }); } - log::debug!( - "Kaldi fbank: [{}, {}]", - features.nrows(), - features.ncols() - ); + log::debug!("Kaldi fbank: [{}, {}]", features.nrows(), features.ncols()); // 2. RNN-T greedy search let token_ids = self.greedy_search(&features)?; @@ -378,9 +403,10 @@ impl ZipformerTransducerModel { features: &Array2, ) -> Result<(Array3, i64), TranscribeError> { let num_frames = features.nrows(); - let feat_3d = features - .to_owned() - .into_shape_with_order((1, num_frames, features.ncols()))?; + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; let lens = ndarray::arr1(&[num_frames as i64]).into_dyn(); let feat_dyn = feat_3d.into_dyn(); @@ -416,8 +442,7 @@ impl ZipformerTransducerModel { /// Run decoder: y [1, context_size] (i64) -> decoder_out [1, D] fn run_decoder(&mut self, context: &[i64]) -> Result, TranscribeError> { - let y = - ndarray::Array2::from_shape_vec((1, self.context_size), context.to_vec())?; + let y = ndarray::Array2::from_shape_vec((1, self.context_size), context.to_vec())?; let y_dyn = y.into_dyn(); let t_y = TensorRef::from_array_view(y_dyn.view())?; diff --git a/src/punct.rs b/src/punct.rs index e7e4d27..16b2bff 100644 --- a/src/punct.rs +++ b/src/punct.rs @@ -235,11 +235,9 @@ impl PunctModel { /// Returns a Vec of punctuation class IDs, one per input token. fn run_inference(&mut self, token_ids: &[i32]) -> Result, TranscribeError> { let seq_len = token_ids.len(); - let input_array = Array2::from_shape_vec( - (1, seq_len), - token_ids.iter().map(|&x| x as i64).collect(), - ) - .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; + let input_array = + Array2::from_shape_vec((1, seq_len), token_ids.iter().map(|&x| x as i64).collect()) + .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; let length_array = Array1::from_vec(vec![seq_len as i64]); let input_tensor = TensorRef::from_array_view(input_array.view()) @@ -366,10 +364,7 @@ impl PunctModel { fn choose_punct_char(pt: PunctType, current_word: &str, result_so_far: &str) -> Option { // If the current word is an English/ASCII word, use ASCII punctuation. // If the preceding content ends in a CJK character, use full-width. - let last_meaningful = result_so_far - .chars() - .rev() - .find(|c| !c.is_whitespace()); + let last_meaningful = result_so_far.chars().rev().find(|c| !c.is_whitespace()); let use_ascii = is_english_token(current_word) || last_meaningful diff --git a/tests/zipformer_transducer.rs b/tests/zipformer_transducer.rs index 955c663..646546b 100644 --- a/tests/zipformer_transducer.rs +++ b/tests/zipformer_transducer.rs @@ -17,9 +17,8 @@ fn test_zipformer_transducer_transcribe() { return; } - let mut model = - ZipformerTransducerModel::load(&model_dir, &Quantization::Int8) - .expect("Failed to load model"); + let mut model = ZipformerTransducerModel::load(&model_dir, &Quantization::Int8) + .expect("Failed to load model"); let result = model .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) From ff86ecd5fa9633f554ea7d34c9ccba1b883e536d Mon Sep 17 00:00:00 2001 From: Pengfei Date: Mon, 30 Mar 2026 07:49:18 +0800 Subject: [PATCH 09/12] Fix punct model int64/int32 type mismatch The CT-Transformer punctuation model expects int32 input tensors, but the code was casting token IDs from i32 to i64. Use i32 directly for both input_array and length_array. Also make output extraction flexible (try i64 first, fall back to i32) since different model versions may output different types. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/punct.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/punct.rs b/src/punct.rs index 16b2bff..386ed6b 100644 --- a/src/punct.rs +++ b/src/punct.rs @@ -235,10 +235,9 @@ impl PunctModel { /// Returns a Vec of punctuation class IDs, one per input token. fn run_inference(&mut self, token_ids: &[i32]) -> Result, TranscribeError> { let seq_len = token_ids.len(); - let input_array = - Array2::from_shape_vec((1, seq_len), token_ids.iter().map(|&x| x as i64).collect()) - .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; - let length_array = Array1::from_vec(vec![seq_len as i64]); + let input_array = Array2::from_shape_vec((1, seq_len), token_ids.to_vec()) + .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; + let length_array = Array1::from_vec(vec![seq_len as i32]); let input_tensor = TensorRef::from_array_view(input_array.view()) .map_err(|e| TranscribeError::Inference(format!("input tensor: {e}")))?; @@ -253,10 +252,15 @@ impl PunctModel { ]) .map_err(|e| TranscribeError::Inference(format!("inference: {e}")))?; - let output = outputs[0] - .try_extract_array::() - .map_err(|e| TranscribeError::Inference(format!("extract output: {e}")))?; - let punct_ids: Vec = output.iter().map(|&x| x as usize).collect(); + // Try i64 first (common for ONNX argmax output), fall back to i32 + let punct_ids: Vec = if let Ok(output) = outputs[0].try_extract_array::() { + output.iter().map(|&x| x as usize).collect() + } else { + let output = outputs[0] + .try_extract_array::() + .map_err(|e| TranscribeError::Inference(format!("extract output: {e}")))?; + output.iter().map(|&x| x as usize).collect() + }; Ok(punct_ids) } From e40c45fa2a133ec0bff2dfc583538e996d9027fa Mon Sep 17 00:00:00 2001 From: Pengfei Date: Mon, 30 Mar 2026 07:55:21 +0800 Subject: [PATCH 10/12] Fix punct output extraction: argmax f32 logits instead of casting ints The CT-Transformer punct model outputs float32 logits with shape [batch, seq_len, num_classes=6], not pre-argmaxed integers. Apply argmax along the last axis to get punctuation class IDs. Fall back to i64/i32 extraction for models that output pre-argmaxed values. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/punct.rs | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/punct.rs b/src/punct.rs index 386ed6b..35ee7cb 100644 --- a/src/punct.rs +++ b/src/punct.rs @@ -252,8 +252,32 @@ impl PunctModel { ]) .map_err(|e| TranscribeError::Inference(format!("inference: {e}")))?; - // Try i64 first (common for ONNX argmax output), fall back to i32 - let punct_ids: Vec = if let Ok(output) = outputs[0].try_extract_array::() { + // Output is logits [batch=1, seq_len, num_classes] — argmax along last axis + let punct_ids: Vec = if let Ok(logits) = outputs[0].try_extract_array::() { + let shape = logits.shape(); + if shape.len() == 3 { + // [1, seq_len, num_classes] → argmax per token + let num_classes = shape[2]; + logits + .as_slice() + .unwrap() + .chunks(num_classes) + .skip(0) // batch dim handled by taking first batch only + .take(seq_len) + .map(|row| { + row.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .map(|(idx, _)| idx) + .unwrap_or(0) + }) + .collect() + } else { + // Unexpected shape, return zeros + vec![0usize; seq_len] + } + } else if let Ok(output) = outputs[0].try_extract_array::() { + // Some models output pre-argmaxed int64 output.iter().map(|&x| x as usize).collect() } else { let output = outputs[0] From 006726cca4b5b80478011ee332ee4437c5fdc553 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Mon, 30 Mar 2026 08:31:06 +0800 Subject: [PATCH 11/12] Downgrade verbose init logs from INFO to DEBUG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reduce log noise during normal operation: - ONNX session model input/output tensor info → DEBUG - BBPE encoding detection → DEBUG - Punct model token count and input names → DEBUG - Zipformer model file discovery → DEBUG Error and warning logs (model load failures, inference errors) remain at WARN/ERROR level for visibility. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/llm-postprocess-quickstart.md | 90 ++ docs/post-processing-analysis.md | 122 +++ .../plans/2026-03-29-port-sherpa-engines.md | 942 ++++++++++++++++++ examples/llm_postprocess.rs | 89 ++ src/decode/bbpe.rs | 4 +- src/llm_postprocess.rs | 273 +++++ src/onnx/session.rs | 4 +- src/onnx/zipformer_ctc/mod.rs | 4 +- src/onnx/zipformer_transducer/mod.rs | 4 +- src/punct.rs | 4 +- 10 files changed, 1526 insertions(+), 10 deletions(-) create mode 100644 docs/llm-postprocess-quickstart.md create mode 100644 docs/post-processing-analysis.md create mode 100644 docs/superpowers/plans/2026-03-29-port-sherpa-engines.md create mode 100644 examples/llm_postprocess.rs create mode 100644 src/llm_postprocess.rs diff --git a/docs/llm-postprocess-quickstart.md b/docs/llm-postprocess-quickstart.md new file mode 100644 index 0000000..e7dfb5d --- /dev/null +++ b/docs/llm-postprocess-quickstart.md @@ -0,0 +1,90 @@ +# LLM 后处理快速测试指南 + +## 环境准备 + +### 1. 下载模型文件 + +```bash +mkdir -p models/qwen2.5-0.5b + +# 下载 GGUF 量化模型(约 350MB) +# 从 HuggingFace 下载 Qwen2.5-0.5B-Instruct 的 GGUF 版本 +# 推荐 q4_k_m 量化(质量与大小平衡) +wget -O models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ + "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf" + +# 下载 tokenizer.json(GGUF 不含 tokenizer,需单独下载) +wget -O models/qwen2.5-0.5b/tokenizer.json \ + "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" +``` + +### 2. 验证文件 + +```bash +ls -lh models/qwen2.5-0.5b/ +# 预期: +# qwen2.5-0.5b-instruct-q4_k_m.gguf ~350MB +# tokenizer.json ~7MB +``` + +## 依赖说明 + +| 依赖 | 版本 | 用途 | +|------|------|------| +| candle-core | 0.9.2 | 张量运算 | +| candle-nn | 0.9.2 | 神经网络层 | +| candle-transformers | 0.9.2 | Qwen2 模型架构 + GGUF 加载 | +| tokenizers | 0.22 | HuggingFace tokenizer 加载 | + +所有依赖均为纯 Rust(CPU 路径),无需安装任何 C/C++ 库。 + +## 编译 + +```bash +# 编译 example(首次编译 candle 约 2-3 分钟) +cargo build --example llm_postprocess --features llm-postprocess +``` + +## 运行 + +```bash +# 使用默认测试文本 +cargo run --example llm_postprocess --features llm-postprocess + +# 自定义输入 +cargo run --example llm_postprocess --features llm-postprocess -- \ + --model models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ + --tokenizer models/qwen2.5-0.5b/tokenizer.json \ + --text "今天天气很好我们去公圆玩吧他说号的" +``` + +### 预期输出 + +``` +原始文本: 今天天气很好我们去公圆玩吧他说号的 +修正文本: 今天天气很好,我们去公园玩吧。他说好的。 +生成 tokens: 25 +耗时: 2.34s +速度: 10.7 tok/s +``` + +> 性能因硬件而异。Apple Silicon (M1/M2) CPU 约 30-60 tok/s,x86 可能更慢。 + +## 已知限制 + +1. **仅 CPU 推理**:当前未启用 Metal/CUDA 加速,纯 CPU 运行 +2. **GGUF tokenizer 缺失**:candle 从 GGUF 文件只读取权重,tokenizer 需单独提供 `tokenizer.json` +3. **首次加载较慢**:模型加载约 1-2s(磁盘 I/O),后续推理约 2-5s/句 +4. **中文为主**:Qwen2.5 对中文纠错效果好,英文效果取决于上下文 +5. **非流式**:当前实现为完整生成后输出,非逐 token 流式 +6. **max_tokens 固定**:默认最大生成 256 tokens,超长文本需分句处理 + +## 故障排查 + +| 问题 | 解决方案 | +|------|---------| +| 编译错误 `candle-core not found` | 确认使用 `--features llm-postprocess` | +| 运行时 `model file not found` | 检查模型路径是否正确 | +| 运行时 `tokenizer.json not found` | 需单独下载 tokenizer,见上方步骤 | +| 输出乱码 | 检查 GGUF 文件是否完整下载 | +| 内存不足 | q4_0 需 ~280MB RAM,q4_k_m 需 ~350MB | diff --git a/docs/post-processing-analysis.md b/docs/post-processing-analysis.md new file mode 100644 index 0000000..3f25108 --- /dev/null +++ b/docs/post-processing-analysis.md @@ -0,0 +1,122 @@ +# ASR 后处理增强分析 + +## 1. 当前架构 + +transcribe-rs 使用 **ort (ONNX Runtime)** 静态编译,支持多种 ASR 引擎(Paraformer、SenseVoice、Zipformer 等), +后处理目前仅有 CT-Transformer 标点恢复(`punct.rs`)。 + +| 维度 | ort + ONNX(当前) | Vosk(已弃用) | +|------|-------------------|---------------| +| 链接方式 | Rust 静态编译,零动态库 | 需要 libvosk.dylib/so | +| 部署复杂度 | 单二进制 + 模型文件 | 二进制 + 动态库 + 模型 | +| 跨平台 | macOS/Linux/Windows 统一 | 每个平台单独编译动态库 | +| 推理性能 | ONNX 优化图,CPU 7ms/句 | 相当 | + +**结论**:ort 静态编译路线已被验证,后续扩展应保持零动态库依赖。 + +--- + +## 2. 后处理方案对比 + +### 2.1 专用标点模型 + +| 模型 | 语言 | 大小 | 延迟 (CPU) | 能力 | +|------|------|------|-----------|------| +| CT-Transformer(当前) | 中英 | ~50MB (int8) | ~7ms/句 | 标点恢复 | +| punct_cap_seg_47lang | 47语言 | ~500MB | ~20ms/句 | 标点 + 大小写 + 分句 | + +### 2.2 小型 LLM + +| 模型 | 参数量 | 量化大小 | 延迟 (CPU) | 能力 | +|------|--------|---------|-----------|------| +| Qwen2.5-0.5B-Instruct | 0.5B | 280MB (q4_0) / 350MB (q4_k_m) | 2-5s/句 | 标点 + 纠错 + 语义修正 | +| Phi-4-mini | 3.8B | ~2.2GB (q4) | 10-30s/句 | 更强纠错,但太慢太大 | + +### 2.3 对比结论 + +| 维度 | CT-Transformer | Qwen2.5-0.5B | +|------|---------------|---------------| +| 速度 | 极快 (~7ms) | 较慢 (~2-5s) | +| 内存 | ~100MB | ~280-350MB | +| 能力 | 仅标点 | 标点 + 纠错 + 语义 | +| 质量 | 标点准确率高 | 可纠正同音字、语法错误 | +| 适用场景 | 实时/批量 | 离线/高质量后处理 | + +--- + +## 3. Rust 推理框架对比 + +| 框架 | 语言 | 链接方式 | GGUF 支持 | 成熟度 | +|------|------|---------|----------|--------| +| **candle** (0.9.2) | 纯 Rust | 静态,零 C 依赖 | 原生支持 | HuggingFace 官方 | +| llama-cpp-rs | Rust bindings → C++ | 静态链接 llama.cpp | 原生 | 成熟但引入 C++ | +| ort(现有) | Rust bindings → ONNX Runtime | 静态链接 | 不支持 GGUF | 项目已在用 | + +**选择 candle 的理由**: +- 纯 Rust,与项目零动态库策略一致 +- 原生支持 GGUF 量化格式(`candle_transformers::quantized`) +- 内置 Qwen2 架构支持(`quantized_qwen2::ModelWeights`) +- HuggingFace 官方维护,API 稳定 +- 与现有 ort 依赖无冲突(candle CPU 路径纯 Rust) + +--- + +## 4. 推荐架构:两层管线 + +``` +ASR 原始文本 + │ + ▼ +┌─────────────────────┐ +│ 第一层:CT-Transformer │ ~7ms,始终运行 +│ (标点恢复) │ +└─────────┬───────────┘ + │ + ▼ +┌─────────────────────┐ +│ 第二层:Qwen2.5-0.5B │ ~2-5s,可选 +│ (纠错 + 语义修正) │ +└─────────┬───────────┘ + │ + ▼ + 最终输出文本 +``` + +### 设计原则 + +1. **第一层始终运行**:CT-Transformer 速度极快,开销可忽略 +2. **第二层可选启用**:通过 feature flag `llm-postprocess` 控制编译 +3. **独立依赖**:candle 依赖与 ort 依赖完全隔离,互不影响 +4. **渐进式增强**:不修改现有 `TranscriptionEngine` trait 和 `punct.rs` + +--- + +## 5. 延迟 / 内存 / 大小对比 + +| 指标 | CT-Transformer | Qwen2.5-0.5B (q4_0) | Qwen2.5-0.5B (q4_k_m) | 两层合计 | +|------|---------------|---------------------|----------------------|---------| +| 模型文件 | ~50MB | ~280MB | ~350MB | ~330-400MB | +| 运行内存 | ~100MB | ~280MB | ~350MB | ~380-450MB | +| 推理延迟 | ~7ms | ~2-3s | ~2-5s | ~2-5s | +| 吞吐量 | ~30-60 tok/s | ~30-60 tok/s (M1) | ~30-60 tok/s (M1) | N/A | +| 编译产物增量 | 已含 | +5-8MB | +5-8MB | +5-8MB | + +> 以上数据基于 Apple Silicon (M1/M2) CPU 推理,x86 平台可能略慢。 + +--- + +## 6. 实施路线 + +### Phase 1(当前) +- 创建独立 example 验证 candle + Qwen2.5-0.5B 可行性 +- 不修改 lib.rs,不影响现有功能 + +### Phase 2(后续) +- 将 LLM 后处理封装为 `llm_postprocess.rs` 模块 +- 集成到 `TranscriptionEngine` trait 的后处理管线 +- 支持通过配置切换是否启用 + +### Phase 3(远期) +- 探索 candle Metal/CUDA 加速 +- 评估 Qwen2.5-1.5B 或更大模型的效果 +- 考虑流式后处理(逐句修正) diff --git a/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md b/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md new file mode 100644 index 0000000..140dac4 --- /dev/null +++ b/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md @@ -0,0 +1,942 @@ +# Port Paraformer/Zipformer Engines to Upstream Architecture + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Port Paraformer, Zipformer CTC, Zipformer Transducer engines and punctuation model from our fork's old `TranscriptionEngine` trait to upstream's `SpeechModel` trait architecture, enabling PR merge into cjpais/transcribe-rs. + +**Architecture:** Create a feature branch from `upstream/main` (v0.3.5). Add shared Kaldi fbank and BBPE decode modules. Port each engine as a new module under `src/onnx/`. Port punct.rs as a standalone feature. All engines use upstream's `session.rs`, `TranscribeError`, and `Quantization`. + +**Tech Stack:** Rust, ort 2.0.0-rc.12, ndarray 0.17, rustfft 6, serde/serde_json + +--- + +## File Structure + +### New files to create: +- `src/features/kaldi_fbank.rs` — Kaldi-compatible fbank (Povey window, DC removal, preemphasis, neg high_freq) +- `src/decode/bbpe.rs` — BBPE byte-to-unicode symbol table + text normalization +- `src/onnx/paraformer/mod.rs` — ParaformerModel + SpeechModel impl +- `src/onnx/zipformer_ctc/mod.rs` — ZipformerCtcModel + SpeechModel impl +- `src/onnx/zipformer_transducer/mod.rs` — ZipformerTransducerModel + SpeechModel impl +- `src/punct.rs` — PunctModel (standalone punctuation post-processor) +- `examples/paraformer.rs` — Paraformer example +- `examples/zipformer_ctc.rs` — Zipformer CTC example +- `examples/zipformer_transducer.rs` — Zipformer Transducer example +- `tests/paraformer.rs` — Paraformer test +- `tests/zipformer_ctc.rs` — Zipformer CTC test +- `tests/zipformer_transducer.rs` — Zipformer Transducer test + +### Files to modify: +- `src/features/mod.rs` — add `pub mod kaldi_fbank;` +- `src/decode/mod.rs` — add `pub mod bbpe;` +- `src/onnx/mod.rs` — add `pub mod paraformer; pub mod zipformer_ctc; pub mod zipformer_transducer;` +- `src/lib.rs` — add `pub mod punct;` under punct feature +- `src/error.rs` — no changes needed (already has ort::Error and serde_json::Error From impls) +- `Cargo.toml` — add `punct` feature, examples, tests + +### Reference files (read from backup branch): +- `backup/pre-upstream-port:src/engines/paraformer/model.rs` — Paraformer inference logic +- `backup/pre-upstream-port:src/engines/paraformer/features.rs` — Paraformer fbank (simpler, non-Kaldi) +- `backup/pre-upstream-port:src/engines/paraformer/tokens.rs` — Paraformer symbol table +- `backup/pre-upstream-port:src/engines/zipformer_common.rs` — Kaldi fbank + BBPE + SymbolTable +- `backup/pre-upstream-port:src/engines/zipformer_ctc/model.rs` — CTC inference +- `backup/pre-upstream-port:src/engines/zipformer_transducer/model.rs` — Transducer inference +- `backup/pre-upstream-port:src/punct.rs` — Punctuation model + +--- + +## Key API Mapping (old → new) + +### ort rc.10 → rc.12 +- `session.inputs` → `session.inputs()` +- `session.outputs` → `session.outputs()` +- `input.name` → `input.name()` +- `input.input_type` → `input.dtype()` +- `output.name` → `output.name()` +- `output.output_type` → `output.dtype()` +- `CPUExecutionProvider::default().build()` → use `session::create_session()` (handles all EPs) +- `metadata.custom(key)?` returns `Option` (no Result wrapping in rc.12; use `session::read_metadata_str`) + +### Trait mapping +- `TranscriptionEngine::transcribe_samples(samples, params)` → `SpeechModel::transcribe_raw(samples, &TranscribeOptions)` +- `Box` → `TranscribeError` +- `ParaformerModel::new(dir, quantized: bool)` → `ParaformerModel::load(dir, &Quantization)` +- Custom error enums → `TranscribeError::{ModelNotFound, Inference, Config, ...}` + +### Feature extraction mapping +- Paraformer uses standard fbank (Hamming window, dB scale) — use upstream `compute_mel()` with appropriate `MelConfig` +- Zipformer uses Kaldi fbank (Povey window, natural log, DC removal) — use new `kaldi_fbank.rs` +- LFR/CMVN — use upstream `features::apply_lfr` and `features::apply_cmvn` + +--- + +### Task 1: Create feature branch from upstream/main + +**Files:** None (git operations only) + +- [ ] **Step 1: Create feature branch** + +```bash +git checkout -b feat/sherpa-engines upstream/main +``` + +- [ ] **Step 2: Verify clean state** + +```bash +cargo check --features onnx +``` + +Expected: compiles clean on upstream/main + +- [ ] **Step 3: Commit (empty, branch marker)** + +No commit needed — clean branch from upstream. + +--- + +### Task 2: Add Kaldi fbank feature extraction + +**Files:** +- Create: `src/features/kaldi_fbank.rs` +- Modify: `src/features/mod.rs` + +- [ ] **Step 1: Create `src/features/kaldi_fbank.rs`** + +Port from `backup/pre-upstream-port:src/engines/zipformer_common.rs` (the `compute_fbank_kaldi` function and `FbankConfig`), adapting to upstream style: + +```rust +//! Kaldi-compatible FBank feature extraction. +//! +//! Matches the behavior of kaldi-native-fbank / sherpa-onnx for Zipformer +//! and Paraformer models that expect Kaldi-style features. + +use ndarray::Array2; +use rustfft::{num_complex::Complex, FftPlanner}; + +/// Kaldi-compatible FBank configuration. +#[derive(Debug, Clone)] +pub struct KaldiFbankConfig { + pub num_bins: usize, + pub fft_size: usize, + pub window_size: usize, + pub hop_size: usize, + pub sample_rate: u32, + pub low_freq: f32, + /// Negative means nyquist + high_freq (Kaldi convention). -400 → 7600 Hz at 16 kHz. + pub high_freq: f32, + pub preemph_coeff: f32, + pub snip_edges: bool, + pub remove_dc_offset: bool, +} + +impl Default for KaldiFbankConfig { + fn default() -> Self { + Self { + num_bins: 80, + fft_size: 512, + window_size: 400, + hop_size: 160, + sample_rate: 16000, + low_freq: 20.0, + high_freq: -400.0, + preemph_coeff: 0.97, + snip_edges: false, + remove_dc_offset: true, + } + } +} + +/// Compute Kaldi-compatible FBank features. +/// +/// Key differences from standard mel spectrogram: +/// - Povey window (Hamming^0.85) instead of plain Hamming/Hann +/// - DC offset removal per frame +/// - Preemphasis applied per frame (reverse order) +/// - snip_edges=false centers first frame and zero-pads boundaries +/// - Natural log energy (not dB) +/// - Negative high_freq interpreted as nyquist + value +/// +/// Returns `[num_frames, num_bins]`. +pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 { + let window_size = config.window_size; + let hop_size = config.hop_size; + let fft_size = config.fft_size; + let half_fft = fft_size / 2 + 1; + + if samples.is_empty() { + return Array2::zeros((0, config.num_bins)); + } + + let num_frames = if config.snip_edges { + if samples.len() < window_size { + return Array2::zeros((0, config.num_bins)); + } + (samples.len() - window_size) / hop_size + 1 + } else { + (samples.len() + hop_size / 2) / hop_size + }; + + if num_frames == 0 { + return Array2::zeros((0, config.num_bins)); + } + + let nyquist = config.sample_rate as f32 / 2.0; + let high_freq = if config.high_freq <= 0.0 { + nyquist + config.high_freq + } else { + config.high_freq + }; + + let filterbank = mel_filterbank(config.num_bins, fft_size, config.sample_rate as f32, config.low_freq, high_freq); + + // Povey window: hamming^0.85 + let window: Vec = (0..window_size) + .map(|i| { + let hamming = 0.54 + - 0.46 + * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); + hamming.powf(0.85) + }) + .collect(); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(fft_size); + + let mut features = Vec::with_capacity(num_frames * config.num_bins); + + for frame_idx in 0..num_frames { + let center = if config.snip_edges { + frame_idx * hop_size + window_size / 2 + } else { + frame_idx * hop_size + }; + let start = center as isize - (window_size as isize / 2); + + // Extract frame with zero-padding at boundaries + let mut frame = vec![0.0f32; window_size]; + for i in 0..window_size { + let idx = start + i as isize; + if idx >= 0 && (idx as usize) < samples.len() { + frame[i] = samples[idx as usize]; + } + } + + // Remove DC offset + if config.remove_dc_offset { + let mean: f32 = frame.iter().sum::() / window_size as f32; + for s in frame.iter_mut() { + *s -= mean; + } + } + + // Preemphasis (reverse order to avoid overwriting) + if config.preemph_coeff > 0.0 { + for i in (1..window_size).rev() { + frame[i] -= config.preemph_coeff * frame[i - 1]; + } + frame[0] *= 1.0 - config.preemph_coeff; + } + + // Apply window and FFT + let mut buffer: Vec> = frame + .iter() + .zip(window.iter()) + .map(|(&s, &w)| Complex::new(s * w, 0.0)) + .collect(); + buffer.resize(fft_size, Complex::new(0.0, 0.0)); + fft.process(&mut buffer); + + // Power spectrum + let power: Vec = buffer[..half_fft].iter().map(|c| c.norm_sqr()).collect(); + + // Apply mel filterbank and take natural log + for filter in &filterbank { + let energy: f32 = filter.iter().zip(power.iter()).map(|(&w, &p)| w * p).sum(); + features.push(if energy > f32::EPSILON { + energy.ln() + } else { + f32::EPSILON.ln() + }); + } + } + + Array2::from_shape_vec((num_frames, config.num_bins), features).unwrap() +} + +fn mel_filterbank( + num_bins: usize, + fft_size: usize, + sample_rate: f32, + low_freq: f32, + high_freq: f32, +) -> Vec> { + let half_fft = fft_size / 2 + 1; + + let hz_to_mel = |hz: f32| 1127.0 * (1.0 + hz / 700.0).ln(); + let mel_to_hz = |mel: f32| 700.0 * ((mel / 1127.0).exp() - 1.0); + + let low_mel = hz_to_mel(low_freq); + let high_mel = hz_to_mel(high_freq); + + let num_points = num_bins + 2; + let mel_points: Vec = (0..num_points) + .map(|i| low_mel + (high_mel - low_mel) * i as f32 / (num_points - 1) as f32) + .collect(); + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + let fft_bins: Vec = hz_points + .iter() + .map(|&hz| ((hz * fft_size as f32) / sample_rate).floor() as usize) + .collect(); + + let mut filterbank = vec![vec![0.0f32; half_fft]; num_bins]; + for (i, filter) in filterbank.iter_mut().enumerate() { + let left = fft_bins[i]; + let center = fft_bins[i + 1]; + let right = fft_bins[i + 2]; + + if center > left { + for j in left..center { + if j < half_fft { + filter[j] = (j - left) as f32 / (center - left) as f32; + } + } + } + if right > center { + for j in center..right { + if j < half_fft { + filter[j] = (right - j) as f32 / (right - center) as f32; + } + } + } + } + + filterbank +} +``` + +- [ ] **Step 2: Register in `src/features/mod.rs`** + +Add after existing exports: + +```rust +pub mod kaldi_fbank; +pub use kaldi_fbank::{compute_kaldi_fbank, KaldiFbankConfig}; +``` + +- [ ] **Step 3: Verify compilation** + +```bash +cargo check --features audio-features +``` + +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/features/kaldi_fbank.rs src/features/mod.rs +git commit -m "feat: add Kaldi-compatible fbank feature extraction" +``` + +--- + +### Task 3: Add BBPE decode module + +**Files:** +- Create: `src/decode/bbpe.rs` +- Modify: `src/decode/mod.rs` + +- [ ] **Step 1: Create `src/decode/bbpe.rs`** + +Port from `backup/pre-upstream-port:src/engines/zipformer_common.rs` (SymbolTable, BBPE mapping, normalize_text): + +```rust +//! BBPE (Byte-level BPE) symbol table for Icefall/sherpa-onnx models. +//! +//! Supports two encoding modes: +//! - BBPE: byte-to-unicode mapped tokens (Icefall zh-en models) +//! - BPE: standard sentencepiece tokens (literal UTF-8) +//! +//! Auto-detects encoding by checking for `bbpe.model` sibling file. + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +/// Token encoding mode. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TokenEncoding { + /// Icefall BBPE: token chars are byte-to-unicode mapped, need decoding. + Bbpe, + /// Standard BPE/sentencepiece: token strings are literal UTF-8. + Bpe, +} + +/// Symbol table with BBPE/BPE decoding support. +pub struct BbpeSymbolTable { + id_to_sym: HashMap, + encoding: TokenEncoding, +} + +impl BbpeSymbolTable { + /// Load with auto-detected encoding. + /// If `bbpe.model` exists in the same directory as `path`, use BBPE; otherwise BPE. + pub fn load(path: &Path) -> Result { + let encoding = if let Some(dir) = path.parent() { + if dir.join("bbpe.model").exists() { + log::info!("Detected BBPE encoding (bbpe.model found)"); + TokenEncoding::Bbpe + } else { + log::info!("Detected standard BPE encoding (no bbpe.model)"); + TokenEncoding::Bpe + } + } else { + TokenEncoding::Bbpe + }; + Self::load_with_encoding(path, encoding) + } + + /// Load with explicit encoding. + pub fn load_with_encoding( + path: &Path, + encoding: TokenEncoding, + ) -> Result { + let contents = fs::read_to_string(path)?; + let mut id_to_sym = HashMap::new(); + + for line in contents.lines() { + let line = line.trim_end(); + if line.is_empty() { + continue; + } + // Format: "token id" (split on last whitespace; token can contain spaces) + let parts: Vec<&str> = line.rsplitn(2, |c: char| c.is_whitespace()).collect(); + if parts.len() == 2 { + if let Ok(id) = parts[0].parse::() { + id_to_sym.insert(id, parts[1].to_string()); + } + } + } + + log::info!( + "Loaded {} tokens from {:?} (encoding={:?})", + id_to_sym.len(), + path, + encoding + ); + Ok(Self { id_to_sym, encoding }) + } + + /// Decode token IDs to text. + pub fn decode(&self, token_ids: &[i32]) -> String { + match self.encoding { + TokenEncoding::Bbpe => self.decode_bbpe(token_ids), + TokenEncoding::Bpe => self.decode_bpe(token_ids), + } + } + + fn decode_bbpe(&self, token_ids: &[i32]) -> String { + let mut raw_bytes = Vec::new(); + + for &id in token_ids { + let Some(sym) = self.id_to_sym.get(&id) else { + continue; + }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + for c in sym.chars() { + if c == '\u{2581}' { + raw_bytes.push(b' '); + } else if let Some(byte_val) = bbpe_char_to_byte(c) { + raw_bytes.push(byte_val); + } + } + } + + let text = String::from_utf8_lossy(&raw_bytes); + normalize_text(text.trim()) + } + + fn decode_bpe(&self, token_ids: &[i32]) -> String { + let mut text = String::new(); + + for &id in token_ids { + let Some(sym) = self.id_to_sym.get(&id) else { + continue; + }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + text.push_str(&sym.replace('\u{2581}', " ")); + } + + normalize_text(text.trim()) + } +} + +// ---- Text normalization ---- + +fn is_cjk(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' | + '\u{3400}'..='\u{4DBF}' | + '\u{F900}'..='\u{FAFF}' | + '\u{2E80}'..='\u{2EFF}' | + '\u{3000}'..='\u{303F}' | + '\u{FF00}'..='\u{FFEF}' + ) +} + +/// Remove spaces between CJK characters and lowercase English text. +fn normalize_text(text: &str) -> String { + let text = text.to_lowercase(); + let chars: Vec = text.chars().collect(); + let mut result = String::with_capacity(text.len()); + + for i in 0..chars.len() { + let c = chars[i]; + if c == ' ' { + let prev_cjk = i > 0 && is_cjk(chars[i - 1]); + let next_cjk = i + 1 < chars.len() && is_cjk(chars[i + 1]); + if prev_cjk && next_cjk { + continue; + } + } + result.push(c); + } + + result +} + +// ---- Icefall BBPE byte mapping ---- + +/// Icefall PRINTABLE_BASE_CHARS: maps byte index (0-255) to a Unicode codepoint. +const BBPE_CODEPOINTS: [u32; 256] = [ + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, + 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, + 286, 287, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 288, 289, 290, 291, 292, + 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, 309, + 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, + 327, 328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, + 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, + 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, + 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, + 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, + 419, 420, 421, 422, +]; + +fn bbpe_char_to_byte(c: char) -> Option { + let cp = c as u32; + if (32..=126).contains(&cp) { + return Some(cp as u8); + } + for (byte_val, &mapped_cp) in BBPE_CODEPOINTS.iter().enumerate() { + if mapped_cp == cp { + return Some(byte_val as u8); + } + } + None +} +``` + +- [ ] **Step 2: Register in `src/decode/mod.rs`** + +Add after existing exports: + +```rust +pub mod bbpe; +pub use bbpe::BbpeSymbolTable; +``` + +- [ ] **Step 3: Verify compilation** + +```bash +cargo check --features audio-features +``` + +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/decode/bbpe.rs src/decode/mod.rs +git commit -m "feat: add BBPE symbol table for Icefall/sherpa-onnx models" +``` + +--- + +### Task 4: Add ParaformerModel + +**Files:** +- Create: `src/onnx/paraformer/mod.rs` +- Modify: `src/onnx/mod.rs` + +- [ ] **Step 1: Create `src/onnx/paraformer/mod.rs`** + +Port from backup branch, adapting to upstream patterns. Key changes: +- Use `session::create_session()` instead of manual session builder +- Use `session::resolve_model_path()` for quantization +- Use `session::read_metadata_i32/float_vec()` for metadata +- Use upstream `features::compute_mel()` with `MelConfig` (Paraformer uses standard Hamming window fbank, NOT Kaldi fbank) +- Use upstream `features::apply_lfr()` and `features::apply_cmvn()` +- Return `TranscribeError` instead of custom error enum +- Implement `SpeechModel` trait +- Use ort rc.12 API (`session.inputs()`, `input.name()`, etc.) + +The complete file should include: +1. `CAPABILITIES` const +2. `ParaformerParams` struct (empty for now — Paraformer is language-auto) +3. `ParaformerModel` struct with session, symbol_table, metadata, cmvn, I/O names +4. `ParaformerModel::load(dir, &Quantization)` constructor +5. Paraformer-specific `SymbolTable` (inline, handles `@@` joining and `▁` markers — different from BBPE) +6. Metadata parsing via `session::read_metadata_i32` +7. CMVN loading from ONNX metadata or `am.mvn` file +8. `compute_features()` → `compute_mel()` + `apply_lfr()` + `apply_cmvn()` +9. `forward()` → run ONNX session +10. `decode_logits()` → argmax with eos/blank/sos filtering +11. `SpeechModel` impl with `transcribe_raw()` + +**Important Paraformer-specific details:** +- Paraformer uses dB scale fbank (10*log10), NOT natural log — use `MelConfig` with `pre_emphasis: None` and standard Hamming window, then manually apply 10*log10 scaling. Actually, looking at the old code more carefully: it uses `10.0 * sum.log10()` with `-80.0` floor. The upstream `compute_mel` with `pre_emphasis: None` uses `ln()`. We need to match the original behavior. +- Solution: Use upstream `compute_mel` with custom `MelConfig{pre_emphasis: None, ...}` — BUT upstream's `compute_mel_spectrogram` uses `ln()`, not `10*log10`. We need to either (a) modify the output, or (b) implement inline. Option (b) is safer to avoid breaking existing models. Implement a private `compute_paraformer_fbank()` inside the module that matches the original exactly. +- LFR default: window_size=7, window_shift=6 (from ONNX metadata) +- CMVN: mean subtraction only (old code uses `apply_mean_cmvn` which subtracts mean; upstream `apply_cmvn` multiplies by inv_stddev too). For Paraformer we only have neg_mean, no inv_stddev. So we do mean-only CMVN inline. +- Symbol table: Paraformer tokens use `@@` for subword joining and `▁` for spaces, plus special tokens ``, ``, ``, ``. This is different from both upstream's SymbolTable and the BBPE SymbolTable. Keep it inline in the module. + +- [ ] **Step 2: Register in `src/onnx/mod.rs`** + +Add: +```rust +pub mod paraformer; +``` + +- [ ] **Step 3: Verify compilation** + +```bash +cargo check --features onnx +``` + +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/onnx/paraformer/ src/onnx/mod.rs +git commit -m "feat: add Paraformer ONNX engine" +``` + +--- + +### Task 5: Add ZipformerCtcModel + +**Files:** +- Create: `src/onnx/zipformer_ctc/mod.rs` +- Modify: `src/onnx/mod.rs` + +- [ ] **Step 1: Create `src/onnx/zipformer_ctc/mod.rs`** + +Port from backup branch. Key adaptations: +- Use `session::create_session()` for session creation +- Use `compute_kaldi_fbank()` from `features::kaldi_fbank` +- Use upstream `ctc_greedy_decode()` from `decode::ctc` — BUT note: upstream CTC takes `ArrayView3` with shape [batch, time, vocab] and `&[i64]` lengths. Our old code had custom CTC with `Array2`. Need to reshape to 3D for upstream API. +- Use `BbpeSymbolTable` from `decode::bbpe` for token decoding +- Model file discovery: keep our smart fallback logic (scan directory for *.onnx) but also try `session::resolve_model_path()` first +- Streaming model rejection: keep the `cached_*` input detection +- Return `TranscribeError` +- Implement `SpeechModel` trait + +- [ ] **Step 2: Register in `src/onnx/mod.rs`** + +Add: +```rust +pub mod zipformer_ctc; +``` + +- [ ] **Step 3: Verify compilation** + +```bash +cargo check --features onnx +``` + +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/onnx/zipformer_ctc/ src/onnx/mod.rs +git commit -m "feat: add Zipformer CTC ONNX engine" +``` + +--- + +### Task 6: Add ZipformerTransducerModel + +**Files:** +- Create: `src/onnx/zipformer_transducer/mod.rs` +- Modify: `src/onnx/mod.rs` + +- [ ] **Step 1: Create `src/onnx/zipformer_transducer/mod.rs`** + +Port from backup branch. This is the most complex engine (3 sessions): +- Use `session::create_session()` for all 3 sessions +- Use `compute_kaldi_fbank()` for features +- Use `BbpeSymbolTable` for token decoding +- Keep the multi-file model discovery logic (`find_model_file` for encoder/decoder/joiner with various naming patterns) +- Keep streaming model rejection +- Keep the RNN-T greedy search decoding loop (no upstream equivalent) +- context_size=2 hardcoded +- Return `TranscribeError` +- Implement `SpeechModel` trait + +**Important:** The transducer's `find_model_file` looks for `{component}-*.{suffix}.onnx` patterns (e.g., `encoder-epoch-34-avg-19.int8.onnx`). This is unique to sherpa-onnx transducer models and must be preserved. + +- [ ] **Step 2: Register in `src/onnx/mod.rs`** + +Add: +```rust +pub mod zipformer_transducer; +``` + +- [ ] **Step 3: Verify compilation** + +```bash +cargo check --features onnx +``` + +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add src/onnx/zipformer_transducer/ src/onnx/mod.rs +git commit -m "feat: add Zipformer Transducer ONNX engine" +``` + +--- + +### Task 7: Add PunctModel + +**Files:** +- Create: `src/punct.rs` +- Modify: `src/lib.rs` +- Modify: `Cargo.toml` + +- [ ] **Step 1: Add `punct` feature to `Cargo.toml`** + +In `[features]` section, add: +```toml +# Neural punctuation restoration (CT-Transformer) +punct = ["dep:ort", "dep:ndarray"] +``` + +Update `all` feature to include `punct`: +```toml +all = ["onnx", "whisper-cpp", "whisperfile", "openai", "punct"] +``` + +- [ ] **Step 2: Create `src/punct.rs`** + +Port from backup branch with these adaptations: +- Use `session::create_session()` instead of manual session builder (but note: punct uses `#[cfg(feature = "punct")]` not `#[cfg(feature = "onnx")]`, and `session` module is under `onnx` feature. So we need to build the session manually for punct, OR gate punct under onnx.) +- **Decision:** Gate punct session creation manually (like `vad-silero` does — it also uses ort directly without the onnx feature). Use `ort` directly: + +```rust +use ort::session::builder::GraphOptimizationLevel; +use ort::session::Session; +``` + +- Keep the custom `PunctError` enum (it's not a SpeechModel, so TranscribeError doesn't fit perfectly. But for consistency, convert to `TranscribeError`.) +- **Decision:** Use `TranscribeError` for consistency with the rest of the crate. The From impls already exist for ort::Error, serde_json::Error, and io::Error. +- Use ort rc.12 API for session inputs/outputs +- Keep all inference logic unchanged + +- [ ] **Step 3: Register in `src/lib.rs`** + +Add after existing module declarations: +```rust +#[cfg(feature = "punct")] +pub mod punct; +``` + +- [ ] **Step 4: Verify compilation** + +```bash +cargo check --features punct +cargo check --features "onnx,punct" +``` + +Expected: both PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/punct.rs src/lib.rs Cargo.toml +git commit -m "feat: add neural punctuation restoration model" +``` + +--- + +### Task 8: Add examples + +**Files:** +- Create: `examples/paraformer.rs` +- Create: `examples/zipformer_ctc.rs` +- Create: `examples/zipformer_transducer.rs` +- Modify: `Cargo.toml` + +- [ ] **Step 1: Create examples** + +Follow the upstream pattern from `examples/gigaam.rs`. Each example: +- Accepts model_dir and wav_path as positional args with defaults +- Supports `--int8` flag +- Shows load time, transcribe time, real-time speedup +- Displays text and segments + +Default model paths: +- Paraformer: `models/sherpa-onnx-paraformer-zh-2025-10-07` +- Zipformer CTC: `models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16` +- Zipformer Transducer: `models/sherpa-onnx-zipformer-zh-en-2023-11-22` + +Default wav: `samples/zh.wav` + +- [ ] **Step 2: Add example declarations to `Cargo.toml`** + +```toml +[[example]] +name = "paraformer" +required-features = ["onnx"] + +[[example]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[example]] +name = "zipformer_transducer" +required-features = ["onnx"] +``` + +- [ ] **Step 3: Verify examples compile** + +```bash +cargo check --example paraformer --features onnx +cargo check --example zipformer_ctc --features onnx +cargo check --example zipformer_transducer --features onnx +``` + +Expected: all PASS + +- [ ] **Step 4: Commit** + +```bash +git add examples/paraformer.rs examples/zipformer_ctc.rs examples/zipformer_transducer.rs Cargo.toml +git commit -m "feat: add examples for Paraformer and Zipformer engines" +``` + +--- + +### Task 9: Add tests + +**Files:** +- Create: `tests/paraformer.rs` +- Create: `tests/zipformer_ctc.rs` +- Create: `tests/zipformer_transducer.rs` +- Modify: `Cargo.toml` + +- [ ] **Step 1: Create test files** + +Follow upstream pattern from `tests/gigaam.rs`. Each test: +- Uses `mod common;` for `require_paths` +- Skips if model/wav not found (graceful skip, not failure) +- Loads model with `Quantization::Int8` +- Transcribes a test WAV +- Asserts expected output text + +- [ ] **Step 2: Add test declarations to `Cargo.toml`** + +```toml +[[test]] +name = "paraformer" +required-features = ["onnx"] + +[[test]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[test]] +name = "zipformer_transducer" +required-features = ["onnx"] +``` + +- [ ] **Step 3: Verify tests compile** + +```bash +cargo test --no-run --features onnx +``` + +Expected: PASS (tests compile; may skip at runtime if models not present) + +- [ ] **Step 4: Commit** + +```bash +git add tests/paraformer.rs tests/zipformer_ctc.rs tests/zipformer_transducer.rs Cargo.toml +git commit -m "test: add tests for Paraformer and Zipformer engines" +``` + +--- + +### Task 10: Full verification + +- [ ] **Step 1: Verify all features compile** + +```bash +cargo check --features onnx +cargo check --features punct +cargo check --features "onnx,punct" +cargo check --features all +``` + +Expected: all PASS + +- [ ] **Step 2: Run cargo clippy** + +```bash +cargo clippy --features "onnx,punct" -- -D warnings +``` + +Expected: PASS (no warnings) + +- [ ] **Step 3: Run cargo fmt** + +```bash +cargo fmt --check +``` + +Expected: PASS + +- [ ] **Step 4: Run tests with models (if available)** + +```bash +cargo test --features onnx -- --nocapture +``` + +- [ ] **Step 5: Run examples with models (if available)** + +```bash +cargo run --example paraformer --features onnx -- models/sherpa-onnx-paraformer-zh-2025-10-07 samples/zh.wav --int8 +cargo run --example zipformer_ctc --features onnx -- models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16 samples/zh.wav --int8 +cargo run --example zipformer_transducer --features onnx -- models/sherpa-onnx-zipformer-zh-en-2023-11-22 samples/zh.wav --int8 +``` + +- [ ] **Step 6: Final commit (if any fmt/clippy fixes)** + +```bash +git add -A +git commit -m "chore: fix clippy warnings and formatting" +``` diff --git a/examples/llm_postprocess.rs b/examples/llm_postprocess.rs new file mode 100644 index 0000000..a875c0e --- /dev/null +++ b/examples/llm_postprocess.rs @@ -0,0 +1,89 @@ +//! LLM-based ASR post-processing example using Qwen2.5-0.5B (GGUF). +//! +//! Demonstrates loading a quantized Qwen2.5 model via the library and using it +//! to add punctuation and correct errors in ASR output. +//! +//! Usage: +//! cargo run --example llm_postprocess --features llm-postprocess --release +//! cargo run --example llm_postprocess --features llm-postprocess --release -- \ +//! --model models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ +//! --tokenizer models/qwen2.5-0.5b/tokenizer.json \ +//! --text "今天天气很好我们去公圆玩吧他说号的" + +use std::io::Write; +use std::path::Path; +use std::time::Instant; + +use transcribe_rs::llm_postprocess::LlmPostProcessor; + +const DEFAULT_MODEL: &str = "models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf"; +const DEFAULT_TOKENIZER: &str = "models/qwen2.5-0.5b/tokenizer.json"; +const DEFAULT_TEXT: &str = "今天天气很好我们去公圆玩吧他说号的"; + +fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + let (model_path, tokenizer_path, input_text) = parse_args(&args); + + println!("=== LLM 后处理验证 (Qwen2.5-0.5B GGUF) ===\n"); + + // 1. Load model + print!("加载模型... "); + std::io::stdout().flush()?; + let load_start = Instant::now(); + + let mut processor = + LlmPostProcessor::from_files(Path::new(&model_path), Path::new(&tokenizer_path))?; + + let load_time = load_start.elapsed(); + println!("完成 ({:.2?})", load_time); + + // 2. Process text + println!(); + println!("原始文本: {}", input_text); + + let gen_start = Instant::now(); + let result = processor.process(&input_text)?; + let gen_time = gen_start.elapsed(); + + println!("修正文本: {}", result); + println!(); + println!("--- 统计 ---"); + println!("耗时: {:.2?}", gen_time); + println!("模型加载: {:.2?}", load_time); + + Ok(()) +} + +fn parse_args(args: &[String]) -> (String, String, String) { + let mut model = DEFAULT_MODEL.to_string(); + let mut tokenizer = DEFAULT_TOKENIZER.to_string(); + let mut text = DEFAULT_TEXT.to_string(); + + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--model" => { + i += 1; + if i < args.len() { + model = args[i].clone(); + } + } + "--tokenizer" => { + i += 1; + if i < args.len() { + tokenizer = args[i].clone(); + } + } + "--text" => { + i += 1; + if i < args.len() { + text = args[i].clone(); + } + } + _ => {} + } + i += 1; + } + + (model, tokenizer, text) +} diff --git a/src/decode/bbpe.rs b/src/decode/bbpe.rs index cf3844d..29d667f 100644 --- a/src/decode/bbpe.rs +++ b/src/decode/bbpe.rs @@ -33,10 +33,10 @@ impl BbpeSymbolTable { pub fn load_autodetect(path: &Path) -> Result { let encoding = if let Some(dir) = path.parent() { if dir.join("bbpe.model").exists() { - log::info!("Detected BBPE encoding (bbpe.model found)"); + log::debug!("Detected BBPE encoding (bbpe.model found)"); TokenEncoding::Bbpe } else { - log::info!("Detected standard BPE encoding (no bbpe.model)"); + log::debug!("Detected standard BPE encoding (no bbpe.model)"); TokenEncoding::Bpe } } else { diff --git a/src/llm_postprocess.rs b/src/llm_postprocess.rs new file mode 100644 index 0000000..cfd7bd9 --- /dev/null +++ b/src/llm_postprocess.rs @@ -0,0 +1,273 @@ +//! LLM-based ASR post-processing using quantized Qwen2.5 (GGUF). +//! +//! Uses a quantized Qwen2.5-0.5B model via candle to add punctuation and +//! correct homophones in ASR output. All inference runs on CPU with no +//! dynamic library dependencies. +//! +//! # Feature gate +//! +//! This module requires the `llm-postprocess` feature: +//! +//! ```toml +//! [dependencies] +//! transcribe-rs = { version = "0.2", features = ["llm-postprocess"] } +//! ``` +//! +//! # Model files +//! +//! You need a GGUF-quantized Qwen2.5 model and its tokenizer: +//! +//! - `qwen2.5-0.5b-instruct-q4_k_m.gguf` (~350 MB) +//! - `tokenizer.json` (from HuggingFace Qwen2.5-0.5B-Instruct) +//! +//! Place them in a single directory (e.g. `models/qwen2.5-0.5b/`). +//! +//! # Usage +//! +//! **Reusable processor** (recommended for multiple calls): +//! +//! ```ignore +//! use std::path::Path; +//! use transcribe_rs::llm_postprocess::LlmPostProcessor; +//! +//! let mut proc = LlmPostProcessor::new(Path::new("models/qwen2.5-0.5b/"))?; +//! +//! let corrected = proc.process("今天天气很好我们去公圆玩吧他说号的")?; +//! println!("{}", corrected); +//! // => "今天天气很好,我们去公园玩吧,他说好的。" +//! ``` +//! +//! **One-shot convenience function**: +//! +//! ```ignore +//! use std::path::Path; +//! use transcribe_rs::llm_postprocess::llm_postprocess; +//! +//! let corrected = llm_postprocess( +//! "今天天气很好我们去公圆玩吧", +//! Path::new("models/qwen2.5-0.5b/"), +//! )?; +//! ``` +//! +//! **Custom system prompt**: +//! +//! ```ignore +//! let corrected = proc.process_with_prompt( +//! "the wether is grate today", +//! "You are a post-processing assistant. Fix punctuation and spelling.", +//! )?; +//! ``` +//! +//! # Pipeline integration +//! +//! Typical ASR post-processing pipeline: +//! +//! 1. **CT-Transformer** (`punct` feature) — fast punctuation restoration (~7 ms) +//! 2. **LLM post-process** (`llm-postprocess` feature) — deep correction (~1-3 s) +//! +//! ```ignore +//! // Step 1: fast punctuation +//! let mut punct = transcribe_rs::PunctModel::new(Path::new("models/punct/"))?; +//! let text = punct.add_punctuation(&raw_asr_text); +//! +//! // Step 2: LLM correction (optional, slower but more accurate) +//! let text = proc.process(&text)?; +//! ``` + +use std::path::Path; + +use candle_core::quantized::gguf_file; +use candle_core::{Device, Tensor}; +use candle_transformers::models::quantized_qwen2::ModelWeights; +use tokenizers::Tokenizer; + +const MAX_TOKENS: usize = 256; +const EOS_TOKEN: &str = "<|im_end|>"; +const DEFAULT_EOS_ID: u32 = 151645; + +const DEFAULT_SYSTEM_PROMPT: &str = "你是语音识别后处理助手。用户输入是语音识别的原始输出,\ +可能缺少标点、含有同音错别字。请添加正确的标点符号,并将同音错别字纠正为正确的字词。\ +只输出纠正后的完整文本。"; + +#[derive(thiserror::Error, Debug)] +pub enum LlmPostProcessError { + #[error("candle error: {0}")] + Candle(#[from] candle_core::Error), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("tokenizer error: {0}")] + Tokenizer(String), + #[error("model file not found: {0}")] + ModelNotFound(String), + #[error("tokenizer file not found: {0}")] + TokenizerNotFound(String), + #[error("no GGUF file found in directory: {0}")] + NoGgufFile(String), +} + +/// LLM-based post-processor holding a quantized Qwen2.5 model and tokenizer. +/// +/// Reuse a single instance for multiple calls to avoid repeated model loading. +pub struct LlmPostProcessor { + model: ModelWeights, + tokenizer: Tokenizer, + device: Device, + eos_token_id: u32, +} + +impl LlmPostProcessor { + /// Load from a model directory containing `*.gguf` and `tokenizer.json`. + pub fn new(model_dir: &Path) -> Result { + let gguf_path = find_gguf_file(model_dir)?; + let tokenizer_path = model_dir.join("tokenizer.json"); + if !tokenizer_path.exists() { + return Err(LlmPostProcessError::TokenizerNotFound( + tokenizer_path.display().to_string(), + )); + } + Self::from_files(&gguf_path, &tokenizer_path) + } + + /// Load from explicit file paths. + pub fn from_files( + gguf_path: &Path, + tokenizer_path: &Path, + ) -> Result { + if !gguf_path.exists() { + return Err(LlmPostProcessError::ModelNotFound( + gguf_path.display().to_string(), + )); + } + if !tokenizer_path.exists() { + return Err(LlmPostProcessError::TokenizerNotFound( + tokenizer_path.display().to_string(), + )); + } + + let device = Device::Cpu; + + let tokenizer = Tokenizer::from_file(tokenizer_path) + .map_err(|e| LlmPostProcessError::Tokenizer(e.to_string()))?; + + let mut file = std::fs::File::open(gguf_path)?; + let content = gguf_file::Content::read(&mut file)?; + let model = ModelWeights::from_gguf(content, &mut file, &device)?; + + let eos_token_id = tokenizer.token_to_id(EOS_TOKEN).unwrap_or(DEFAULT_EOS_ID); + + Ok(Self { + model, + tokenizer, + device, + eos_token_id, + }) + } + + /// Process ASR text using the default system prompt. + pub fn process(&mut self, text: &str) -> Result { + self.process_with_prompt(text, DEFAULT_SYSTEM_PROMPT) + } + + /// Process ASR text using a custom system prompt. + pub fn process_with_prompt( + &mut self, + text: &str, + system_prompt: &str, + ) -> Result { + let prompt = format!( + "<|im_start|>system\n{system_prompt}<|im_end|>\n\ + <|im_start|>user\n\ + 请纠正以下语音识别文本中的标点和错别字:\n\ + {text}<|im_end|>\n\ + <|im_start|>assistant\n" + ); + + let encoding = self + .tokenizer + .encode(prompt.as_str(), true) + .map_err(|e| LlmPostProcessError::Tokenizer(e.to_string()))?; + let prompt_tokens = encoding.get_ids().to_vec(); + let prompt_len = prompt_tokens.len(); + + // Feed prompt through the model + let input = Tensor::new(prompt_tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, 0)?; + let last_logits = extract_last_logits(&logits)?; + let mut next_token = sample_greedy(&last_logits)?; + + let mut output_text = String::new(); + let mut generated_tokens: usize = 0; + + for _ in 0..MAX_TOKENS { + if next_token == self.eos_token_id { + break; + } + + generated_tokens += 1; + + if let Ok(decoded) = self.tokenizer.decode(&[next_token], false) { + output_text.push_str(&decoded); + } + + // Forward pass for next token + let input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; + let pos = prompt_len + generated_tokens - 1; + let logits = self.model.forward(&input, pos)?; + let last_logits = extract_last_logits(&logits)?; + next_token = sample_greedy(&last_logits)?; + } + + Ok(output_text) + } +} + +/// Convenience function that loads the model and processes text in one call. +/// +/// For repeated use, prefer creating an [`LlmPostProcessor`] instance directly. +pub fn llm_postprocess(text: &str, model_dir: &Path) -> Result { + let mut processor = LlmPostProcessor::new(model_dir)?; + processor.process(text) +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +/// Find the first `*.gguf` file in a directory. +fn find_gguf_file(dir: &Path) -> Result { + if dir.is_file() && dir.extension().is_some_and(|e| e == "gguf") { + return Ok(dir.to_path_buf()); + } + + let entries = std::fs::read_dir(dir)?; + for entry in entries { + let entry = entry?; + let path = entry.path(); + if path.extension().is_some_and(|e| e == "gguf") { + return Ok(path); + } + } + + Err(LlmPostProcessError::NoGgufFile(dir.display().to_string())) +} + +/// Extract the last position's logits, handling 1D/2D/3D tensor shapes. +fn extract_last_logits(logits: &Tensor) -> Result { + match logits.dims().len() { + 3 => { + let logits = logits.squeeze(0)?; + logits.get(logits.dim(0)? - 1) + } + 2 => logits.get(logits.dim(0)? - 1), + 1 => Ok(logits.clone()), + _ => Err(candle_core::Error::Msg(format!( + "unexpected logits shape: {:?}", + logits.dims() + ))), + } +} + +/// Greedy (argmax) token sampling. +fn sample_greedy(logits: &Tensor) -> Result { + logits.argmax(0)?.to_scalar::() +} diff --git a/src/onnx/session.rs b/src/onnx/session.rs index 676c6da..9a8ff4c 100644 --- a/src/onnx/session.rs +++ b/src/onnx/session.rs @@ -129,14 +129,14 @@ fn build_session( .commit_from_file(path)?; for input in session.inputs() { - log::info!( + log::debug!( "Model input: name={}, type={:?}", input.name(), input.dtype() ); } for output in session.outputs() { - log::info!( + log::debug!( "Model output: name={}, type={:?}", output.name(), output.dtype() diff --git a/src/onnx/zipformer_ctc/mod.rs b/src/onnx/zipformer_ctc/mod.rs index 40058d4..57b812f 100644 --- a/src/onnx/zipformer_ctc/mod.rs +++ b/src/onnx/zipformer_ctc/mod.rs @@ -200,13 +200,13 @@ impl ZipformerCtcModel { if prefer_int8 { if let Some(p) = int8_candidates.into_iter().next() { - log::info!("Found int8 model by directory scan: {:?}", p); + log::debug!("Found int8 model by directory scan: {:?}", p); return Ok(p); } } if let Some(p) = fp32_candidates.into_iter().next() { - log::info!("Found model by directory scan: {:?}", p); + log::debug!("Found model by directory scan: {:?}", p); return Ok(p); } diff --git a/src/onnx/zipformer_transducer/mod.rs b/src/onnx/zipformer_transducer/mod.rs index 9eef497..89838b7 100644 --- a/src/onnx/zipformer_transducer/mod.rs +++ b/src/onnx/zipformer_transducer/mod.rs @@ -331,7 +331,7 @@ impl ZipformerTransducerModel { // 3. Prefer suffixed match if let Some(p) = suffixed_candidates.into_iter().next() { - log::info!( + log::debug!( "Found {} model by directory scan: {:?}", component, p.file_name().unwrap_or_default() @@ -341,7 +341,7 @@ impl ZipformerTransducerModel { // 4. Fall back to any .onnx match if let Some(p) = plain_candidates.into_iter().next() { - log::info!( + log::debug!( "Found {} model by directory scan: {:?}", component, p.file_name().unwrap_or_default() diff --git a/src/punct.rs b/src/punct.rs index 35ee7cb..f273755 100644 --- a/src/punct.rs +++ b/src/punct.rs @@ -149,7 +149,7 @@ impl PunctModel { let input_name = session.inputs()[0].name().to_string(); let length_name = session.inputs()[1].name().to_string(); - log::info!( + log::debug!( "Punct model input names: '{}' and '{}'", input_name, length_name @@ -172,7 +172,7 @@ impl PunctModel { token2id.insert(token.clone(), id as i32); } let unk_id = *token2id.get("").unwrap_or(&0); - log::info!("Loaded {} tokens, unk_id={}", tokens.len(), unk_id); + log::debug!("Loaded {} tokens, unk_id={}", tokens.len(), unk_id); Ok((token2id, unk_id)) } From bfab8ceaa9d998760da19a796815ce3f846e8116 Mon Sep 17 00:00:00 2001 From: Pengfei Date: Mon, 30 Mar 2026 08:57:37 +0800 Subject: [PATCH 12/12] chore: remove accidentally committed files and fix clippy - Remove llm_postprocess module (not yet ported, broke example build) - Remove stale docs and plan files - Fix clippy skip(0) warning in punct.rs Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/llm-postprocess-quickstart.md | 90 -- docs/post-processing-analysis.md | 122 --- .../plans/2026-03-29-port-sherpa-engines.md | 942 ------------------ examples/llm_postprocess.rs | 89 -- src/llm_postprocess.rs | 273 ----- src/punct.rs | 1 - 6 files changed, 1517 deletions(-) delete mode 100644 docs/llm-postprocess-quickstart.md delete mode 100644 docs/post-processing-analysis.md delete mode 100644 docs/superpowers/plans/2026-03-29-port-sherpa-engines.md delete mode 100644 examples/llm_postprocess.rs delete mode 100644 src/llm_postprocess.rs diff --git a/docs/llm-postprocess-quickstart.md b/docs/llm-postprocess-quickstart.md deleted file mode 100644 index e7dfb5d..0000000 --- a/docs/llm-postprocess-quickstart.md +++ /dev/null @@ -1,90 +0,0 @@ -# LLM 后处理快速测试指南 - -## 环境准备 - -### 1. 下载模型文件 - -```bash -mkdir -p models/qwen2.5-0.5b - -# 下载 GGUF 量化模型(约 350MB) -# 从 HuggingFace 下载 Qwen2.5-0.5B-Instruct 的 GGUF 版本 -# 推荐 q4_k_m 量化(质量与大小平衡) -wget -O models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ - "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/qwen2.5-0.5b-instruct-q4_k_m.gguf" - -# 下载 tokenizer.json(GGUF 不含 tokenizer,需单独下载) -wget -O models/qwen2.5-0.5b/tokenizer.json \ - "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" -``` - -### 2. 验证文件 - -```bash -ls -lh models/qwen2.5-0.5b/ -# 预期: -# qwen2.5-0.5b-instruct-q4_k_m.gguf ~350MB -# tokenizer.json ~7MB -``` - -## 依赖说明 - -| 依赖 | 版本 | 用途 | -|------|------|------| -| candle-core | 0.9.2 | 张量运算 | -| candle-nn | 0.9.2 | 神经网络层 | -| candle-transformers | 0.9.2 | Qwen2 模型架构 + GGUF 加载 | -| tokenizers | 0.22 | HuggingFace tokenizer 加载 | - -所有依赖均为纯 Rust(CPU 路径),无需安装任何 C/C++ 库。 - -## 编译 - -```bash -# 编译 example(首次编译 candle 约 2-3 分钟) -cargo build --example llm_postprocess --features llm-postprocess -``` - -## 运行 - -```bash -# 使用默认测试文本 -cargo run --example llm_postprocess --features llm-postprocess - -# 自定义输入 -cargo run --example llm_postprocess --features llm-postprocess -- \ - --model models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ - --tokenizer models/qwen2.5-0.5b/tokenizer.json \ - --text "今天天气很好我们去公圆玩吧他说号的" -``` - -### 预期输出 - -``` -原始文本: 今天天气很好我们去公圆玩吧他说号的 -修正文本: 今天天气很好,我们去公园玩吧。他说好的。 -生成 tokens: 25 -耗时: 2.34s -速度: 10.7 tok/s -``` - -> 性能因硬件而异。Apple Silicon (M1/M2) CPU 约 30-60 tok/s,x86 可能更慢。 - -## 已知限制 - -1. **仅 CPU 推理**:当前未启用 Metal/CUDA 加速,纯 CPU 运行 -2. **GGUF tokenizer 缺失**:candle 从 GGUF 文件只读取权重,tokenizer 需单独提供 `tokenizer.json` -3. **首次加载较慢**:模型加载约 1-2s(磁盘 I/O),后续推理约 2-5s/句 -4. **中文为主**:Qwen2.5 对中文纠错效果好,英文效果取决于上下文 -5. **非流式**:当前实现为完整生成后输出,非逐 token 流式 -6. **max_tokens 固定**:默认最大生成 256 tokens,超长文本需分句处理 - -## 故障排查 - -| 问题 | 解决方案 | -|------|---------| -| 编译错误 `candle-core not found` | 确认使用 `--features llm-postprocess` | -| 运行时 `model file not found` | 检查模型路径是否正确 | -| 运行时 `tokenizer.json not found` | 需单独下载 tokenizer,见上方步骤 | -| 输出乱码 | 检查 GGUF 文件是否完整下载 | -| 内存不足 | q4_0 需 ~280MB RAM,q4_k_m 需 ~350MB | diff --git a/docs/post-processing-analysis.md b/docs/post-processing-analysis.md deleted file mode 100644 index 3f25108..0000000 --- a/docs/post-processing-analysis.md +++ /dev/null @@ -1,122 +0,0 @@ -# ASR 后处理增强分析 - -## 1. 当前架构 - -transcribe-rs 使用 **ort (ONNX Runtime)** 静态编译,支持多种 ASR 引擎(Paraformer、SenseVoice、Zipformer 等), -后处理目前仅有 CT-Transformer 标点恢复(`punct.rs`)。 - -| 维度 | ort + ONNX(当前) | Vosk(已弃用) | -|------|-------------------|---------------| -| 链接方式 | Rust 静态编译,零动态库 | 需要 libvosk.dylib/so | -| 部署复杂度 | 单二进制 + 模型文件 | 二进制 + 动态库 + 模型 | -| 跨平台 | macOS/Linux/Windows 统一 | 每个平台单独编译动态库 | -| 推理性能 | ONNX 优化图,CPU 7ms/句 | 相当 | - -**结论**:ort 静态编译路线已被验证,后续扩展应保持零动态库依赖。 - ---- - -## 2. 后处理方案对比 - -### 2.1 专用标点模型 - -| 模型 | 语言 | 大小 | 延迟 (CPU) | 能力 | -|------|------|------|-----------|------| -| CT-Transformer(当前) | 中英 | ~50MB (int8) | ~7ms/句 | 标点恢复 | -| punct_cap_seg_47lang | 47语言 | ~500MB | ~20ms/句 | 标点 + 大小写 + 分句 | - -### 2.2 小型 LLM - -| 模型 | 参数量 | 量化大小 | 延迟 (CPU) | 能力 | -|------|--------|---------|-----------|------| -| Qwen2.5-0.5B-Instruct | 0.5B | 280MB (q4_0) / 350MB (q4_k_m) | 2-5s/句 | 标点 + 纠错 + 语义修正 | -| Phi-4-mini | 3.8B | ~2.2GB (q4) | 10-30s/句 | 更强纠错,但太慢太大 | - -### 2.3 对比结论 - -| 维度 | CT-Transformer | Qwen2.5-0.5B | -|------|---------------|---------------| -| 速度 | 极快 (~7ms) | 较慢 (~2-5s) | -| 内存 | ~100MB | ~280-350MB | -| 能力 | 仅标点 | 标点 + 纠错 + 语义 | -| 质量 | 标点准确率高 | 可纠正同音字、语法错误 | -| 适用场景 | 实时/批量 | 离线/高质量后处理 | - ---- - -## 3. Rust 推理框架对比 - -| 框架 | 语言 | 链接方式 | GGUF 支持 | 成熟度 | -|------|------|---------|----------|--------| -| **candle** (0.9.2) | 纯 Rust | 静态,零 C 依赖 | 原生支持 | HuggingFace 官方 | -| llama-cpp-rs | Rust bindings → C++ | 静态链接 llama.cpp | 原生 | 成熟但引入 C++ | -| ort(现有) | Rust bindings → ONNX Runtime | 静态链接 | 不支持 GGUF | 项目已在用 | - -**选择 candle 的理由**: -- 纯 Rust,与项目零动态库策略一致 -- 原生支持 GGUF 量化格式(`candle_transformers::quantized`) -- 内置 Qwen2 架构支持(`quantized_qwen2::ModelWeights`) -- HuggingFace 官方维护,API 稳定 -- 与现有 ort 依赖无冲突(candle CPU 路径纯 Rust) - ---- - -## 4. 推荐架构:两层管线 - -``` -ASR 原始文本 - │ - ▼ -┌─────────────────────┐ -│ 第一层:CT-Transformer │ ~7ms,始终运行 -│ (标点恢复) │ -└─────────┬───────────┘ - │ - ▼ -┌─────────────────────┐ -│ 第二层:Qwen2.5-0.5B │ ~2-5s,可选 -│ (纠错 + 语义修正) │ -└─────────┬───────────┘ - │ - ▼ - 最终输出文本 -``` - -### 设计原则 - -1. **第一层始终运行**:CT-Transformer 速度极快,开销可忽略 -2. **第二层可选启用**:通过 feature flag `llm-postprocess` 控制编译 -3. **独立依赖**:candle 依赖与 ort 依赖完全隔离,互不影响 -4. **渐进式增强**:不修改现有 `TranscriptionEngine` trait 和 `punct.rs` - ---- - -## 5. 延迟 / 内存 / 大小对比 - -| 指标 | CT-Transformer | Qwen2.5-0.5B (q4_0) | Qwen2.5-0.5B (q4_k_m) | 两层合计 | -|------|---------------|---------------------|----------------------|---------| -| 模型文件 | ~50MB | ~280MB | ~350MB | ~330-400MB | -| 运行内存 | ~100MB | ~280MB | ~350MB | ~380-450MB | -| 推理延迟 | ~7ms | ~2-3s | ~2-5s | ~2-5s | -| 吞吐量 | ~30-60 tok/s | ~30-60 tok/s (M1) | ~30-60 tok/s (M1) | N/A | -| 编译产物增量 | 已含 | +5-8MB | +5-8MB | +5-8MB | - -> 以上数据基于 Apple Silicon (M1/M2) CPU 推理,x86 平台可能略慢。 - ---- - -## 6. 实施路线 - -### Phase 1(当前) -- 创建独立 example 验证 candle + Qwen2.5-0.5B 可行性 -- 不修改 lib.rs,不影响现有功能 - -### Phase 2(后续) -- 将 LLM 后处理封装为 `llm_postprocess.rs` 模块 -- 集成到 `TranscriptionEngine` trait 的后处理管线 -- 支持通过配置切换是否启用 - -### Phase 3(远期) -- 探索 candle Metal/CUDA 加速 -- 评估 Qwen2.5-1.5B 或更大模型的效果 -- 考虑流式后处理(逐句修正) diff --git a/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md b/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md deleted file mode 100644 index 140dac4..0000000 --- a/docs/superpowers/plans/2026-03-29-port-sherpa-engines.md +++ /dev/null @@ -1,942 +0,0 @@ -# Port Paraformer/Zipformer Engines to Upstream Architecture - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Port Paraformer, Zipformer CTC, Zipformer Transducer engines and punctuation model from our fork's old `TranscriptionEngine` trait to upstream's `SpeechModel` trait architecture, enabling PR merge into cjpais/transcribe-rs. - -**Architecture:** Create a feature branch from `upstream/main` (v0.3.5). Add shared Kaldi fbank and BBPE decode modules. Port each engine as a new module under `src/onnx/`. Port punct.rs as a standalone feature. All engines use upstream's `session.rs`, `TranscribeError`, and `Quantization`. - -**Tech Stack:** Rust, ort 2.0.0-rc.12, ndarray 0.17, rustfft 6, serde/serde_json - ---- - -## File Structure - -### New files to create: -- `src/features/kaldi_fbank.rs` — Kaldi-compatible fbank (Povey window, DC removal, preemphasis, neg high_freq) -- `src/decode/bbpe.rs` — BBPE byte-to-unicode symbol table + text normalization -- `src/onnx/paraformer/mod.rs` — ParaformerModel + SpeechModel impl -- `src/onnx/zipformer_ctc/mod.rs` — ZipformerCtcModel + SpeechModel impl -- `src/onnx/zipformer_transducer/mod.rs` — ZipformerTransducerModel + SpeechModel impl -- `src/punct.rs` — PunctModel (standalone punctuation post-processor) -- `examples/paraformer.rs` — Paraformer example -- `examples/zipformer_ctc.rs` — Zipformer CTC example -- `examples/zipformer_transducer.rs` — Zipformer Transducer example -- `tests/paraformer.rs` — Paraformer test -- `tests/zipformer_ctc.rs` — Zipformer CTC test -- `tests/zipformer_transducer.rs` — Zipformer Transducer test - -### Files to modify: -- `src/features/mod.rs` — add `pub mod kaldi_fbank;` -- `src/decode/mod.rs` — add `pub mod bbpe;` -- `src/onnx/mod.rs` — add `pub mod paraformer; pub mod zipformer_ctc; pub mod zipformer_transducer;` -- `src/lib.rs` — add `pub mod punct;` under punct feature -- `src/error.rs` — no changes needed (already has ort::Error and serde_json::Error From impls) -- `Cargo.toml` — add `punct` feature, examples, tests - -### Reference files (read from backup branch): -- `backup/pre-upstream-port:src/engines/paraformer/model.rs` — Paraformer inference logic -- `backup/pre-upstream-port:src/engines/paraformer/features.rs` — Paraformer fbank (simpler, non-Kaldi) -- `backup/pre-upstream-port:src/engines/paraformer/tokens.rs` — Paraformer symbol table -- `backup/pre-upstream-port:src/engines/zipformer_common.rs` — Kaldi fbank + BBPE + SymbolTable -- `backup/pre-upstream-port:src/engines/zipformer_ctc/model.rs` — CTC inference -- `backup/pre-upstream-port:src/engines/zipformer_transducer/model.rs` — Transducer inference -- `backup/pre-upstream-port:src/punct.rs` — Punctuation model - ---- - -## Key API Mapping (old → new) - -### ort rc.10 → rc.12 -- `session.inputs` → `session.inputs()` -- `session.outputs` → `session.outputs()` -- `input.name` → `input.name()` -- `input.input_type` → `input.dtype()` -- `output.name` → `output.name()` -- `output.output_type` → `output.dtype()` -- `CPUExecutionProvider::default().build()` → use `session::create_session()` (handles all EPs) -- `metadata.custom(key)?` returns `Option` (no Result wrapping in rc.12; use `session::read_metadata_str`) - -### Trait mapping -- `TranscriptionEngine::transcribe_samples(samples, params)` → `SpeechModel::transcribe_raw(samples, &TranscribeOptions)` -- `Box` → `TranscribeError` -- `ParaformerModel::new(dir, quantized: bool)` → `ParaformerModel::load(dir, &Quantization)` -- Custom error enums → `TranscribeError::{ModelNotFound, Inference, Config, ...}` - -### Feature extraction mapping -- Paraformer uses standard fbank (Hamming window, dB scale) — use upstream `compute_mel()` with appropriate `MelConfig` -- Zipformer uses Kaldi fbank (Povey window, natural log, DC removal) — use new `kaldi_fbank.rs` -- LFR/CMVN — use upstream `features::apply_lfr` and `features::apply_cmvn` - ---- - -### Task 1: Create feature branch from upstream/main - -**Files:** None (git operations only) - -- [ ] **Step 1: Create feature branch** - -```bash -git checkout -b feat/sherpa-engines upstream/main -``` - -- [ ] **Step 2: Verify clean state** - -```bash -cargo check --features onnx -``` - -Expected: compiles clean on upstream/main - -- [ ] **Step 3: Commit (empty, branch marker)** - -No commit needed — clean branch from upstream. - ---- - -### Task 2: Add Kaldi fbank feature extraction - -**Files:** -- Create: `src/features/kaldi_fbank.rs` -- Modify: `src/features/mod.rs` - -- [ ] **Step 1: Create `src/features/kaldi_fbank.rs`** - -Port from `backup/pre-upstream-port:src/engines/zipformer_common.rs` (the `compute_fbank_kaldi` function and `FbankConfig`), adapting to upstream style: - -```rust -//! Kaldi-compatible FBank feature extraction. -//! -//! Matches the behavior of kaldi-native-fbank / sherpa-onnx for Zipformer -//! and Paraformer models that expect Kaldi-style features. - -use ndarray::Array2; -use rustfft::{num_complex::Complex, FftPlanner}; - -/// Kaldi-compatible FBank configuration. -#[derive(Debug, Clone)] -pub struct KaldiFbankConfig { - pub num_bins: usize, - pub fft_size: usize, - pub window_size: usize, - pub hop_size: usize, - pub sample_rate: u32, - pub low_freq: f32, - /// Negative means nyquist + high_freq (Kaldi convention). -400 → 7600 Hz at 16 kHz. - pub high_freq: f32, - pub preemph_coeff: f32, - pub snip_edges: bool, - pub remove_dc_offset: bool, -} - -impl Default for KaldiFbankConfig { - fn default() -> Self { - Self { - num_bins: 80, - fft_size: 512, - window_size: 400, - hop_size: 160, - sample_rate: 16000, - low_freq: 20.0, - high_freq: -400.0, - preemph_coeff: 0.97, - snip_edges: false, - remove_dc_offset: true, - } - } -} - -/// Compute Kaldi-compatible FBank features. -/// -/// Key differences from standard mel spectrogram: -/// - Povey window (Hamming^0.85) instead of plain Hamming/Hann -/// - DC offset removal per frame -/// - Preemphasis applied per frame (reverse order) -/// - snip_edges=false centers first frame and zero-pads boundaries -/// - Natural log energy (not dB) -/// - Negative high_freq interpreted as nyquist + value -/// -/// Returns `[num_frames, num_bins]`. -pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 { - let window_size = config.window_size; - let hop_size = config.hop_size; - let fft_size = config.fft_size; - let half_fft = fft_size / 2 + 1; - - if samples.is_empty() { - return Array2::zeros((0, config.num_bins)); - } - - let num_frames = if config.snip_edges { - if samples.len() < window_size { - return Array2::zeros((0, config.num_bins)); - } - (samples.len() - window_size) / hop_size + 1 - } else { - (samples.len() + hop_size / 2) / hop_size - }; - - if num_frames == 0 { - return Array2::zeros((0, config.num_bins)); - } - - let nyquist = config.sample_rate as f32 / 2.0; - let high_freq = if config.high_freq <= 0.0 { - nyquist + config.high_freq - } else { - config.high_freq - }; - - let filterbank = mel_filterbank(config.num_bins, fft_size, config.sample_rate as f32, config.low_freq, high_freq); - - // Povey window: hamming^0.85 - let window: Vec = (0..window_size) - .map(|i| { - let hamming = 0.54 - - 0.46 - * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); - hamming.powf(0.85) - }) - .collect(); - - let mut planner = FftPlanner::new(); - let fft = planner.plan_fft_forward(fft_size); - - let mut features = Vec::with_capacity(num_frames * config.num_bins); - - for frame_idx in 0..num_frames { - let center = if config.snip_edges { - frame_idx * hop_size + window_size / 2 - } else { - frame_idx * hop_size - }; - let start = center as isize - (window_size as isize / 2); - - // Extract frame with zero-padding at boundaries - let mut frame = vec![0.0f32; window_size]; - for i in 0..window_size { - let idx = start + i as isize; - if idx >= 0 && (idx as usize) < samples.len() { - frame[i] = samples[idx as usize]; - } - } - - // Remove DC offset - if config.remove_dc_offset { - let mean: f32 = frame.iter().sum::() / window_size as f32; - for s in frame.iter_mut() { - *s -= mean; - } - } - - // Preemphasis (reverse order to avoid overwriting) - if config.preemph_coeff > 0.0 { - for i in (1..window_size).rev() { - frame[i] -= config.preemph_coeff * frame[i - 1]; - } - frame[0] *= 1.0 - config.preemph_coeff; - } - - // Apply window and FFT - let mut buffer: Vec> = frame - .iter() - .zip(window.iter()) - .map(|(&s, &w)| Complex::new(s * w, 0.0)) - .collect(); - buffer.resize(fft_size, Complex::new(0.0, 0.0)); - fft.process(&mut buffer); - - // Power spectrum - let power: Vec = buffer[..half_fft].iter().map(|c| c.norm_sqr()).collect(); - - // Apply mel filterbank and take natural log - for filter in &filterbank { - let energy: f32 = filter.iter().zip(power.iter()).map(|(&w, &p)| w * p).sum(); - features.push(if energy > f32::EPSILON { - energy.ln() - } else { - f32::EPSILON.ln() - }); - } - } - - Array2::from_shape_vec((num_frames, config.num_bins), features).unwrap() -} - -fn mel_filterbank( - num_bins: usize, - fft_size: usize, - sample_rate: f32, - low_freq: f32, - high_freq: f32, -) -> Vec> { - let half_fft = fft_size / 2 + 1; - - let hz_to_mel = |hz: f32| 1127.0 * (1.0 + hz / 700.0).ln(); - let mel_to_hz = |mel: f32| 700.0 * ((mel / 1127.0).exp() - 1.0); - - let low_mel = hz_to_mel(low_freq); - let high_mel = hz_to_mel(high_freq); - - let num_points = num_bins + 2; - let mel_points: Vec = (0..num_points) - .map(|i| low_mel + (high_mel - low_mel) * i as f32 / (num_points - 1) as f32) - .collect(); - let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); - let fft_bins: Vec = hz_points - .iter() - .map(|&hz| ((hz * fft_size as f32) / sample_rate).floor() as usize) - .collect(); - - let mut filterbank = vec![vec![0.0f32; half_fft]; num_bins]; - for (i, filter) in filterbank.iter_mut().enumerate() { - let left = fft_bins[i]; - let center = fft_bins[i + 1]; - let right = fft_bins[i + 2]; - - if center > left { - for j in left..center { - if j < half_fft { - filter[j] = (j - left) as f32 / (center - left) as f32; - } - } - } - if right > center { - for j in center..right { - if j < half_fft { - filter[j] = (right - j) as f32 / (right - center) as f32; - } - } - } - } - - filterbank -} -``` - -- [ ] **Step 2: Register in `src/features/mod.rs`** - -Add after existing exports: - -```rust -pub mod kaldi_fbank; -pub use kaldi_fbank::{compute_kaldi_fbank, KaldiFbankConfig}; -``` - -- [ ] **Step 3: Verify compilation** - -```bash -cargo check --features audio-features -``` - -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add src/features/kaldi_fbank.rs src/features/mod.rs -git commit -m "feat: add Kaldi-compatible fbank feature extraction" -``` - ---- - -### Task 3: Add BBPE decode module - -**Files:** -- Create: `src/decode/bbpe.rs` -- Modify: `src/decode/mod.rs` - -- [ ] **Step 1: Create `src/decode/bbpe.rs`** - -Port from `backup/pre-upstream-port:src/engines/zipformer_common.rs` (SymbolTable, BBPE mapping, normalize_text): - -```rust -//! BBPE (Byte-level BPE) symbol table for Icefall/sherpa-onnx models. -//! -//! Supports two encoding modes: -//! - BBPE: byte-to-unicode mapped tokens (Icefall zh-en models) -//! - BPE: standard sentencepiece tokens (literal UTF-8) -//! -//! Auto-detects encoding by checking for `bbpe.model` sibling file. - -use std::collections::HashMap; -use std::fs; -use std::path::Path; - -/// Token encoding mode. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum TokenEncoding { - /// Icefall BBPE: token chars are byte-to-unicode mapped, need decoding. - Bbpe, - /// Standard BPE/sentencepiece: token strings are literal UTF-8. - Bpe, -} - -/// Symbol table with BBPE/BPE decoding support. -pub struct BbpeSymbolTable { - id_to_sym: HashMap, - encoding: TokenEncoding, -} - -impl BbpeSymbolTable { - /// Load with auto-detected encoding. - /// If `bbpe.model` exists in the same directory as `path`, use BBPE; otherwise BPE. - pub fn load(path: &Path) -> Result { - let encoding = if let Some(dir) = path.parent() { - if dir.join("bbpe.model").exists() { - log::info!("Detected BBPE encoding (bbpe.model found)"); - TokenEncoding::Bbpe - } else { - log::info!("Detected standard BPE encoding (no bbpe.model)"); - TokenEncoding::Bpe - } - } else { - TokenEncoding::Bbpe - }; - Self::load_with_encoding(path, encoding) - } - - /// Load with explicit encoding. - pub fn load_with_encoding( - path: &Path, - encoding: TokenEncoding, - ) -> Result { - let contents = fs::read_to_string(path)?; - let mut id_to_sym = HashMap::new(); - - for line in contents.lines() { - let line = line.trim_end(); - if line.is_empty() { - continue; - } - // Format: "token id" (split on last whitespace; token can contain spaces) - let parts: Vec<&str> = line.rsplitn(2, |c: char| c.is_whitespace()).collect(); - if parts.len() == 2 { - if let Ok(id) = parts[0].parse::() { - id_to_sym.insert(id, parts[1].to_string()); - } - } - } - - log::info!( - "Loaded {} tokens from {:?} (encoding={:?})", - id_to_sym.len(), - path, - encoding - ); - Ok(Self { id_to_sym, encoding }) - } - - /// Decode token IDs to text. - pub fn decode(&self, token_ids: &[i32]) -> String { - match self.encoding { - TokenEncoding::Bbpe => self.decode_bbpe(token_ids), - TokenEncoding::Bpe => self.decode_bpe(token_ids), - } - } - - fn decode_bbpe(&self, token_ids: &[i32]) -> String { - let mut raw_bytes = Vec::new(); - - for &id in token_ids { - let Some(sym) = self.id_to_sym.get(&id) else { - continue; - }; - if sym.starts_with('<') && sym.ends_with('>') { - continue; - } - for c in sym.chars() { - if c == '\u{2581}' { - raw_bytes.push(b' '); - } else if let Some(byte_val) = bbpe_char_to_byte(c) { - raw_bytes.push(byte_val); - } - } - } - - let text = String::from_utf8_lossy(&raw_bytes); - normalize_text(text.trim()) - } - - fn decode_bpe(&self, token_ids: &[i32]) -> String { - let mut text = String::new(); - - for &id in token_ids { - let Some(sym) = self.id_to_sym.get(&id) else { - continue; - }; - if sym.starts_with('<') && sym.ends_with('>') { - continue; - } - text.push_str(&sym.replace('\u{2581}', " ")); - } - - normalize_text(text.trim()) - } -} - -// ---- Text normalization ---- - -fn is_cjk(c: char) -> bool { - matches!(c, - '\u{4E00}'..='\u{9FFF}' | - '\u{3400}'..='\u{4DBF}' | - '\u{F900}'..='\u{FAFF}' | - '\u{2E80}'..='\u{2EFF}' | - '\u{3000}'..='\u{303F}' | - '\u{FF00}'..='\u{FFEF}' - ) -} - -/// Remove spaces between CJK characters and lowercase English text. -fn normalize_text(text: &str) -> String { - let text = text.to_lowercase(); - let chars: Vec = text.chars().collect(); - let mut result = String::with_capacity(text.len()); - - for i in 0..chars.len() { - let c = chars[i]; - if c == ' ' { - let prev_cjk = i > 0 && is_cjk(chars[i - 1]); - let next_cjk = i + 1 < chars.len() && is_cjk(chars[i + 1]); - if prev_cjk && next_cjk { - continue; - } - } - result.push(c); - } - - result -} - -// ---- Icefall BBPE byte mapping ---- - -/// Icefall PRINTABLE_BASE_CHARS: maps byte index (0-255) to a Unicode codepoint. -const BBPE_CODEPOINTS: [u32; 256] = [ - 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, - 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, - 286, 287, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, - 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, - 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, - 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 288, 289, 290, 291, 292, - 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, 309, - 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, - 327, 328, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, - 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, - 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, - 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, - 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, - 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, - 419, 420, 421, 422, -]; - -fn bbpe_char_to_byte(c: char) -> Option { - let cp = c as u32; - if (32..=126).contains(&cp) { - return Some(cp as u8); - } - for (byte_val, &mapped_cp) in BBPE_CODEPOINTS.iter().enumerate() { - if mapped_cp == cp { - return Some(byte_val as u8); - } - } - None -} -``` - -- [ ] **Step 2: Register in `src/decode/mod.rs`** - -Add after existing exports: - -```rust -pub mod bbpe; -pub use bbpe::BbpeSymbolTable; -``` - -- [ ] **Step 3: Verify compilation** - -```bash -cargo check --features audio-features -``` - -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add src/decode/bbpe.rs src/decode/mod.rs -git commit -m "feat: add BBPE symbol table for Icefall/sherpa-onnx models" -``` - ---- - -### Task 4: Add ParaformerModel - -**Files:** -- Create: `src/onnx/paraformer/mod.rs` -- Modify: `src/onnx/mod.rs` - -- [ ] **Step 1: Create `src/onnx/paraformer/mod.rs`** - -Port from backup branch, adapting to upstream patterns. Key changes: -- Use `session::create_session()` instead of manual session builder -- Use `session::resolve_model_path()` for quantization -- Use `session::read_metadata_i32/float_vec()` for metadata -- Use upstream `features::compute_mel()` with `MelConfig` (Paraformer uses standard Hamming window fbank, NOT Kaldi fbank) -- Use upstream `features::apply_lfr()` and `features::apply_cmvn()` -- Return `TranscribeError` instead of custom error enum -- Implement `SpeechModel` trait -- Use ort rc.12 API (`session.inputs()`, `input.name()`, etc.) - -The complete file should include: -1. `CAPABILITIES` const -2. `ParaformerParams` struct (empty for now — Paraformer is language-auto) -3. `ParaformerModel` struct with session, symbol_table, metadata, cmvn, I/O names -4. `ParaformerModel::load(dir, &Quantization)` constructor -5. Paraformer-specific `SymbolTable` (inline, handles `@@` joining and `▁` markers — different from BBPE) -6. Metadata parsing via `session::read_metadata_i32` -7. CMVN loading from ONNX metadata or `am.mvn` file -8. `compute_features()` → `compute_mel()` + `apply_lfr()` + `apply_cmvn()` -9. `forward()` → run ONNX session -10. `decode_logits()` → argmax with eos/blank/sos filtering -11. `SpeechModel` impl with `transcribe_raw()` - -**Important Paraformer-specific details:** -- Paraformer uses dB scale fbank (10*log10), NOT natural log — use `MelConfig` with `pre_emphasis: None` and standard Hamming window, then manually apply 10*log10 scaling. Actually, looking at the old code more carefully: it uses `10.0 * sum.log10()` with `-80.0` floor. The upstream `compute_mel` with `pre_emphasis: None` uses `ln()`. We need to match the original behavior. -- Solution: Use upstream `compute_mel` with custom `MelConfig{pre_emphasis: None, ...}` — BUT upstream's `compute_mel_spectrogram` uses `ln()`, not `10*log10`. We need to either (a) modify the output, or (b) implement inline. Option (b) is safer to avoid breaking existing models. Implement a private `compute_paraformer_fbank()` inside the module that matches the original exactly. -- LFR default: window_size=7, window_shift=6 (from ONNX metadata) -- CMVN: mean subtraction only (old code uses `apply_mean_cmvn` which subtracts mean; upstream `apply_cmvn` multiplies by inv_stddev too). For Paraformer we only have neg_mean, no inv_stddev. So we do mean-only CMVN inline. -- Symbol table: Paraformer tokens use `@@` for subword joining and `▁` for spaces, plus special tokens ``, ``, ``, ``. This is different from both upstream's SymbolTable and the BBPE SymbolTable. Keep it inline in the module. - -- [ ] **Step 2: Register in `src/onnx/mod.rs`** - -Add: -```rust -pub mod paraformer; -``` - -- [ ] **Step 3: Verify compilation** - -```bash -cargo check --features onnx -``` - -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add src/onnx/paraformer/ src/onnx/mod.rs -git commit -m "feat: add Paraformer ONNX engine" -``` - ---- - -### Task 5: Add ZipformerCtcModel - -**Files:** -- Create: `src/onnx/zipformer_ctc/mod.rs` -- Modify: `src/onnx/mod.rs` - -- [ ] **Step 1: Create `src/onnx/zipformer_ctc/mod.rs`** - -Port from backup branch. Key adaptations: -- Use `session::create_session()` for session creation -- Use `compute_kaldi_fbank()` from `features::kaldi_fbank` -- Use upstream `ctc_greedy_decode()` from `decode::ctc` — BUT note: upstream CTC takes `ArrayView3` with shape [batch, time, vocab] and `&[i64]` lengths. Our old code had custom CTC with `Array2`. Need to reshape to 3D for upstream API. -- Use `BbpeSymbolTable` from `decode::bbpe` for token decoding -- Model file discovery: keep our smart fallback logic (scan directory for *.onnx) but also try `session::resolve_model_path()` first -- Streaming model rejection: keep the `cached_*` input detection -- Return `TranscribeError` -- Implement `SpeechModel` trait - -- [ ] **Step 2: Register in `src/onnx/mod.rs`** - -Add: -```rust -pub mod zipformer_ctc; -``` - -- [ ] **Step 3: Verify compilation** - -```bash -cargo check --features onnx -``` - -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add src/onnx/zipformer_ctc/ src/onnx/mod.rs -git commit -m "feat: add Zipformer CTC ONNX engine" -``` - ---- - -### Task 6: Add ZipformerTransducerModel - -**Files:** -- Create: `src/onnx/zipformer_transducer/mod.rs` -- Modify: `src/onnx/mod.rs` - -- [ ] **Step 1: Create `src/onnx/zipformer_transducer/mod.rs`** - -Port from backup branch. This is the most complex engine (3 sessions): -- Use `session::create_session()` for all 3 sessions -- Use `compute_kaldi_fbank()` for features -- Use `BbpeSymbolTable` for token decoding -- Keep the multi-file model discovery logic (`find_model_file` for encoder/decoder/joiner with various naming patterns) -- Keep streaming model rejection -- Keep the RNN-T greedy search decoding loop (no upstream equivalent) -- context_size=2 hardcoded -- Return `TranscribeError` -- Implement `SpeechModel` trait - -**Important:** The transducer's `find_model_file` looks for `{component}-*.{suffix}.onnx` patterns (e.g., `encoder-epoch-34-avg-19.int8.onnx`). This is unique to sherpa-onnx transducer models and must be preserved. - -- [ ] **Step 2: Register in `src/onnx/mod.rs`** - -Add: -```rust -pub mod zipformer_transducer; -``` - -- [ ] **Step 3: Verify compilation** - -```bash -cargo check --features onnx -``` - -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add src/onnx/zipformer_transducer/ src/onnx/mod.rs -git commit -m "feat: add Zipformer Transducer ONNX engine" -``` - ---- - -### Task 7: Add PunctModel - -**Files:** -- Create: `src/punct.rs` -- Modify: `src/lib.rs` -- Modify: `Cargo.toml` - -- [ ] **Step 1: Add `punct` feature to `Cargo.toml`** - -In `[features]` section, add: -```toml -# Neural punctuation restoration (CT-Transformer) -punct = ["dep:ort", "dep:ndarray"] -``` - -Update `all` feature to include `punct`: -```toml -all = ["onnx", "whisper-cpp", "whisperfile", "openai", "punct"] -``` - -- [ ] **Step 2: Create `src/punct.rs`** - -Port from backup branch with these adaptations: -- Use `session::create_session()` instead of manual session builder (but note: punct uses `#[cfg(feature = "punct")]` not `#[cfg(feature = "onnx")]`, and `session` module is under `onnx` feature. So we need to build the session manually for punct, OR gate punct under onnx.) -- **Decision:** Gate punct session creation manually (like `vad-silero` does — it also uses ort directly without the onnx feature). Use `ort` directly: - -```rust -use ort::session::builder::GraphOptimizationLevel; -use ort::session::Session; -``` - -- Keep the custom `PunctError` enum (it's not a SpeechModel, so TranscribeError doesn't fit perfectly. But for consistency, convert to `TranscribeError`.) -- **Decision:** Use `TranscribeError` for consistency with the rest of the crate. The From impls already exist for ort::Error, serde_json::Error, and io::Error. -- Use ort rc.12 API for session inputs/outputs -- Keep all inference logic unchanged - -- [ ] **Step 3: Register in `src/lib.rs`** - -Add after existing module declarations: -```rust -#[cfg(feature = "punct")] -pub mod punct; -``` - -- [ ] **Step 4: Verify compilation** - -```bash -cargo check --features punct -cargo check --features "onnx,punct" -``` - -Expected: both PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/punct.rs src/lib.rs Cargo.toml -git commit -m "feat: add neural punctuation restoration model" -``` - ---- - -### Task 8: Add examples - -**Files:** -- Create: `examples/paraformer.rs` -- Create: `examples/zipformer_ctc.rs` -- Create: `examples/zipformer_transducer.rs` -- Modify: `Cargo.toml` - -- [ ] **Step 1: Create examples** - -Follow the upstream pattern from `examples/gigaam.rs`. Each example: -- Accepts model_dir and wav_path as positional args with defaults -- Supports `--int8` flag -- Shows load time, transcribe time, real-time speedup -- Displays text and segments - -Default model paths: -- Paraformer: `models/sherpa-onnx-paraformer-zh-2025-10-07` -- Zipformer CTC: `models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16` -- Zipformer Transducer: `models/sherpa-onnx-zipformer-zh-en-2023-11-22` - -Default wav: `samples/zh.wav` - -- [ ] **Step 2: Add example declarations to `Cargo.toml`** - -```toml -[[example]] -name = "paraformer" -required-features = ["onnx"] - -[[example]] -name = "zipformer_ctc" -required-features = ["onnx"] - -[[example]] -name = "zipformer_transducer" -required-features = ["onnx"] -``` - -- [ ] **Step 3: Verify examples compile** - -```bash -cargo check --example paraformer --features onnx -cargo check --example zipformer_ctc --features onnx -cargo check --example zipformer_transducer --features onnx -``` - -Expected: all PASS - -- [ ] **Step 4: Commit** - -```bash -git add examples/paraformer.rs examples/zipformer_ctc.rs examples/zipformer_transducer.rs Cargo.toml -git commit -m "feat: add examples for Paraformer and Zipformer engines" -``` - ---- - -### Task 9: Add tests - -**Files:** -- Create: `tests/paraformer.rs` -- Create: `tests/zipformer_ctc.rs` -- Create: `tests/zipformer_transducer.rs` -- Modify: `Cargo.toml` - -- [ ] **Step 1: Create test files** - -Follow upstream pattern from `tests/gigaam.rs`. Each test: -- Uses `mod common;` for `require_paths` -- Skips if model/wav not found (graceful skip, not failure) -- Loads model with `Quantization::Int8` -- Transcribes a test WAV -- Asserts expected output text - -- [ ] **Step 2: Add test declarations to `Cargo.toml`** - -```toml -[[test]] -name = "paraformer" -required-features = ["onnx"] - -[[test]] -name = "zipformer_ctc" -required-features = ["onnx"] - -[[test]] -name = "zipformer_transducer" -required-features = ["onnx"] -``` - -- [ ] **Step 3: Verify tests compile** - -```bash -cargo test --no-run --features onnx -``` - -Expected: PASS (tests compile; may skip at runtime if models not present) - -- [ ] **Step 4: Commit** - -```bash -git add tests/paraformer.rs tests/zipformer_ctc.rs tests/zipformer_transducer.rs Cargo.toml -git commit -m "test: add tests for Paraformer and Zipformer engines" -``` - ---- - -### Task 10: Full verification - -- [ ] **Step 1: Verify all features compile** - -```bash -cargo check --features onnx -cargo check --features punct -cargo check --features "onnx,punct" -cargo check --features all -``` - -Expected: all PASS - -- [ ] **Step 2: Run cargo clippy** - -```bash -cargo clippy --features "onnx,punct" -- -D warnings -``` - -Expected: PASS (no warnings) - -- [ ] **Step 3: Run cargo fmt** - -```bash -cargo fmt --check -``` - -Expected: PASS - -- [ ] **Step 4: Run tests with models (if available)** - -```bash -cargo test --features onnx -- --nocapture -``` - -- [ ] **Step 5: Run examples with models (if available)** - -```bash -cargo run --example paraformer --features onnx -- models/sherpa-onnx-paraformer-zh-2025-10-07 samples/zh.wav --int8 -cargo run --example zipformer_ctc --features onnx -- models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16 samples/zh.wav --int8 -cargo run --example zipformer_transducer --features onnx -- models/sherpa-onnx-zipformer-zh-en-2023-11-22 samples/zh.wav --int8 -``` - -- [ ] **Step 6: Final commit (if any fmt/clippy fixes)** - -```bash -git add -A -git commit -m "chore: fix clippy warnings and formatting" -``` diff --git a/examples/llm_postprocess.rs b/examples/llm_postprocess.rs deleted file mode 100644 index a875c0e..0000000 --- a/examples/llm_postprocess.rs +++ /dev/null @@ -1,89 +0,0 @@ -//! LLM-based ASR post-processing example using Qwen2.5-0.5B (GGUF). -//! -//! Demonstrates loading a quantized Qwen2.5 model via the library and using it -//! to add punctuation and correct errors in ASR output. -//! -//! Usage: -//! cargo run --example llm_postprocess --features llm-postprocess --release -//! cargo run --example llm_postprocess --features llm-postprocess --release -- \ -//! --model models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf \ -//! --tokenizer models/qwen2.5-0.5b/tokenizer.json \ -//! --text "今天天气很好我们去公圆玩吧他说号的" - -use std::io::Write; -use std::path::Path; -use std::time::Instant; - -use transcribe_rs::llm_postprocess::LlmPostProcessor; - -const DEFAULT_MODEL: &str = "models/qwen2.5-0.5b/qwen2.5-0.5b-instruct-q4_k_m.gguf"; -const DEFAULT_TOKENIZER: &str = "models/qwen2.5-0.5b/tokenizer.json"; -const DEFAULT_TEXT: &str = "今天天气很好我们去公圆玩吧他说号的"; - -fn main() -> Result<(), Box> { - let args: Vec = std::env::args().collect(); - let (model_path, tokenizer_path, input_text) = parse_args(&args); - - println!("=== LLM 后处理验证 (Qwen2.5-0.5B GGUF) ===\n"); - - // 1. Load model - print!("加载模型... "); - std::io::stdout().flush()?; - let load_start = Instant::now(); - - let mut processor = - LlmPostProcessor::from_files(Path::new(&model_path), Path::new(&tokenizer_path))?; - - let load_time = load_start.elapsed(); - println!("完成 ({:.2?})", load_time); - - // 2. Process text - println!(); - println!("原始文本: {}", input_text); - - let gen_start = Instant::now(); - let result = processor.process(&input_text)?; - let gen_time = gen_start.elapsed(); - - println!("修正文本: {}", result); - println!(); - println!("--- 统计 ---"); - println!("耗时: {:.2?}", gen_time); - println!("模型加载: {:.2?}", load_time); - - Ok(()) -} - -fn parse_args(args: &[String]) -> (String, String, String) { - let mut model = DEFAULT_MODEL.to_string(); - let mut tokenizer = DEFAULT_TOKENIZER.to_string(); - let mut text = DEFAULT_TEXT.to_string(); - - let mut i = 1; - while i < args.len() { - match args[i].as_str() { - "--model" => { - i += 1; - if i < args.len() { - model = args[i].clone(); - } - } - "--tokenizer" => { - i += 1; - if i < args.len() { - tokenizer = args[i].clone(); - } - } - "--text" => { - i += 1; - if i < args.len() { - text = args[i].clone(); - } - } - _ => {} - } - i += 1; - } - - (model, tokenizer, text) -} diff --git a/src/llm_postprocess.rs b/src/llm_postprocess.rs deleted file mode 100644 index cfd7bd9..0000000 --- a/src/llm_postprocess.rs +++ /dev/null @@ -1,273 +0,0 @@ -//! LLM-based ASR post-processing using quantized Qwen2.5 (GGUF). -//! -//! Uses a quantized Qwen2.5-0.5B model via candle to add punctuation and -//! correct homophones in ASR output. All inference runs on CPU with no -//! dynamic library dependencies. -//! -//! # Feature gate -//! -//! This module requires the `llm-postprocess` feature: -//! -//! ```toml -//! [dependencies] -//! transcribe-rs = { version = "0.2", features = ["llm-postprocess"] } -//! ``` -//! -//! # Model files -//! -//! You need a GGUF-quantized Qwen2.5 model and its tokenizer: -//! -//! - `qwen2.5-0.5b-instruct-q4_k_m.gguf` (~350 MB) -//! - `tokenizer.json` (from HuggingFace Qwen2.5-0.5B-Instruct) -//! -//! Place them in a single directory (e.g. `models/qwen2.5-0.5b/`). -//! -//! # Usage -//! -//! **Reusable processor** (recommended for multiple calls): -//! -//! ```ignore -//! use std::path::Path; -//! use transcribe_rs::llm_postprocess::LlmPostProcessor; -//! -//! let mut proc = LlmPostProcessor::new(Path::new("models/qwen2.5-0.5b/"))?; -//! -//! let corrected = proc.process("今天天气很好我们去公圆玩吧他说号的")?; -//! println!("{}", corrected); -//! // => "今天天气很好,我们去公园玩吧,他说好的。" -//! ``` -//! -//! **One-shot convenience function**: -//! -//! ```ignore -//! use std::path::Path; -//! use transcribe_rs::llm_postprocess::llm_postprocess; -//! -//! let corrected = llm_postprocess( -//! "今天天气很好我们去公圆玩吧", -//! Path::new("models/qwen2.5-0.5b/"), -//! )?; -//! ``` -//! -//! **Custom system prompt**: -//! -//! ```ignore -//! let corrected = proc.process_with_prompt( -//! "the wether is grate today", -//! "You are a post-processing assistant. Fix punctuation and spelling.", -//! )?; -//! ``` -//! -//! # Pipeline integration -//! -//! Typical ASR post-processing pipeline: -//! -//! 1. **CT-Transformer** (`punct` feature) — fast punctuation restoration (~7 ms) -//! 2. **LLM post-process** (`llm-postprocess` feature) — deep correction (~1-3 s) -//! -//! ```ignore -//! // Step 1: fast punctuation -//! let mut punct = transcribe_rs::PunctModel::new(Path::new("models/punct/"))?; -//! let text = punct.add_punctuation(&raw_asr_text); -//! -//! // Step 2: LLM correction (optional, slower but more accurate) -//! let text = proc.process(&text)?; -//! ``` - -use std::path::Path; - -use candle_core::quantized::gguf_file; -use candle_core::{Device, Tensor}; -use candle_transformers::models::quantized_qwen2::ModelWeights; -use tokenizers::Tokenizer; - -const MAX_TOKENS: usize = 256; -const EOS_TOKEN: &str = "<|im_end|>"; -const DEFAULT_EOS_ID: u32 = 151645; - -const DEFAULT_SYSTEM_PROMPT: &str = "你是语音识别后处理助手。用户输入是语音识别的原始输出,\ -可能缺少标点、含有同音错别字。请添加正确的标点符号,并将同音错别字纠正为正确的字词。\ -只输出纠正后的完整文本。"; - -#[derive(thiserror::Error, Debug)] -pub enum LlmPostProcessError { - #[error("candle error: {0}")] - Candle(#[from] candle_core::Error), - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - #[error("tokenizer error: {0}")] - Tokenizer(String), - #[error("model file not found: {0}")] - ModelNotFound(String), - #[error("tokenizer file not found: {0}")] - TokenizerNotFound(String), - #[error("no GGUF file found in directory: {0}")] - NoGgufFile(String), -} - -/// LLM-based post-processor holding a quantized Qwen2.5 model and tokenizer. -/// -/// Reuse a single instance for multiple calls to avoid repeated model loading. -pub struct LlmPostProcessor { - model: ModelWeights, - tokenizer: Tokenizer, - device: Device, - eos_token_id: u32, -} - -impl LlmPostProcessor { - /// Load from a model directory containing `*.gguf` and `tokenizer.json`. - pub fn new(model_dir: &Path) -> Result { - let gguf_path = find_gguf_file(model_dir)?; - let tokenizer_path = model_dir.join("tokenizer.json"); - if !tokenizer_path.exists() { - return Err(LlmPostProcessError::TokenizerNotFound( - tokenizer_path.display().to_string(), - )); - } - Self::from_files(&gguf_path, &tokenizer_path) - } - - /// Load from explicit file paths. - pub fn from_files( - gguf_path: &Path, - tokenizer_path: &Path, - ) -> Result { - if !gguf_path.exists() { - return Err(LlmPostProcessError::ModelNotFound( - gguf_path.display().to_string(), - )); - } - if !tokenizer_path.exists() { - return Err(LlmPostProcessError::TokenizerNotFound( - tokenizer_path.display().to_string(), - )); - } - - let device = Device::Cpu; - - let tokenizer = Tokenizer::from_file(tokenizer_path) - .map_err(|e| LlmPostProcessError::Tokenizer(e.to_string()))?; - - let mut file = std::fs::File::open(gguf_path)?; - let content = gguf_file::Content::read(&mut file)?; - let model = ModelWeights::from_gguf(content, &mut file, &device)?; - - let eos_token_id = tokenizer.token_to_id(EOS_TOKEN).unwrap_or(DEFAULT_EOS_ID); - - Ok(Self { - model, - tokenizer, - device, - eos_token_id, - }) - } - - /// Process ASR text using the default system prompt. - pub fn process(&mut self, text: &str) -> Result { - self.process_with_prompt(text, DEFAULT_SYSTEM_PROMPT) - } - - /// Process ASR text using a custom system prompt. - pub fn process_with_prompt( - &mut self, - text: &str, - system_prompt: &str, - ) -> Result { - let prompt = format!( - "<|im_start|>system\n{system_prompt}<|im_end|>\n\ - <|im_start|>user\n\ - 请纠正以下语音识别文本中的标点和错别字:\n\ - {text}<|im_end|>\n\ - <|im_start|>assistant\n" - ); - - let encoding = self - .tokenizer - .encode(prompt.as_str(), true) - .map_err(|e| LlmPostProcessError::Tokenizer(e.to_string()))?; - let prompt_tokens = encoding.get_ids().to_vec(); - let prompt_len = prompt_tokens.len(); - - // Feed prompt through the model - let input = Tensor::new(prompt_tokens.as_slice(), &self.device)?.unsqueeze(0)?; - let logits = self.model.forward(&input, 0)?; - let last_logits = extract_last_logits(&logits)?; - let mut next_token = sample_greedy(&last_logits)?; - - let mut output_text = String::new(); - let mut generated_tokens: usize = 0; - - for _ in 0..MAX_TOKENS { - if next_token == self.eos_token_id { - break; - } - - generated_tokens += 1; - - if let Ok(decoded) = self.tokenizer.decode(&[next_token], false) { - output_text.push_str(&decoded); - } - - // Forward pass for next token - let input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; - let pos = prompt_len + generated_tokens - 1; - let logits = self.model.forward(&input, pos)?; - let last_logits = extract_last_logits(&logits)?; - next_token = sample_greedy(&last_logits)?; - } - - Ok(output_text) - } -} - -/// Convenience function that loads the model and processes text in one call. -/// -/// For repeated use, prefer creating an [`LlmPostProcessor`] instance directly. -pub fn llm_postprocess(text: &str, model_dir: &Path) -> Result { - let mut processor = LlmPostProcessor::new(model_dir)?; - processor.process(text) -} - -// --------------------------------------------------------------------------- -// Internal helpers -// --------------------------------------------------------------------------- - -/// Find the first `*.gguf` file in a directory. -fn find_gguf_file(dir: &Path) -> Result { - if dir.is_file() && dir.extension().is_some_and(|e| e == "gguf") { - return Ok(dir.to_path_buf()); - } - - let entries = std::fs::read_dir(dir)?; - for entry in entries { - let entry = entry?; - let path = entry.path(); - if path.extension().is_some_and(|e| e == "gguf") { - return Ok(path); - } - } - - Err(LlmPostProcessError::NoGgufFile(dir.display().to_string())) -} - -/// Extract the last position's logits, handling 1D/2D/3D tensor shapes. -fn extract_last_logits(logits: &Tensor) -> Result { - match logits.dims().len() { - 3 => { - let logits = logits.squeeze(0)?; - logits.get(logits.dim(0)? - 1) - } - 2 => logits.get(logits.dim(0)? - 1), - 1 => Ok(logits.clone()), - _ => Err(candle_core::Error::Msg(format!( - "unexpected logits shape: {:?}", - logits.dims() - ))), - } -} - -/// Greedy (argmax) token sampling. -fn sample_greedy(logits: &Tensor) -> Result { - logits.argmax(0)?.to_scalar::() -} diff --git a/src/punct.rs b/src/punct.rs index f273755..97d69fd 100644 --- a/src/punct.rs +++ b/src/punct.rs @@ -262,7 +262,6 @@ impl PunctModel { .as_slice() .unwrap() .chunks(num_classes) - .skip(0) // batch dim handled by taking first batch only .take(seq_len) .map(|row| { row.iter()