diff --git a/Cargo.toml b/Cargo.toml index d1395be..3cd77a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,9 @@ openai = ["dep:async-openai", "dep:tokio", "dep:async-trait"] # Silero neural VAD (requires silero_vad_v4.onnx model file) vad-silero = ["dep:ort", "dep:ndarray"] +# Neural punctuation restoration (CT-Transformer) +punct = ["dep:ort", "dep:ndarray"] + # --- ORT Accelerators --- # Note: ort-cuda pulls in the CUDA execution provider, which adds ~800 MB+ # to the ORT binary and requires a CUDA toolkit / cuDNN installation at runtime. @@ -44,7 +47,7 @@ ort-webgpu = ["onnx", "ort/webgpu"] ort-accel = ["ort-cuda", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu"] # Convenience -all = ["onnx", "whisper-cpp", "whisperfile", "openai"] +all = ["onnx", "whisper-cpp", "whisperfile", "openai", "punct"] [dependencies] # Always required @@ -115,6 +118,18 @@ required-features = ["whisperfile"] name = "openai" required-features = ["openai"] +[[example]] +name = "paraformer" +required-features = ["onnx"] + +[[example]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[example]] +name = "zipformer_transducer" +required-features = ["onnx"] + [dev-dependencies] once_cell = "1.21.3" @@ -158,3 +173,15 @@ required-features = ["onnx", "vad-silero"] [[test]] name = "vad_silero" required-features = ["vad-silero"] + +[[test]] +name = "paraformer" +required-features = ["onnx"] + +[[test]] +name = "zipformer_ctc" +required-features = ["onnx"] + +[[test]] +name = "zipformer_transducer" +required-features = ["onnx"] diff --git a/examples/paraformer.rs b/examples/paraformer.rs new file mode 100644 index 0000000..550bf2a --- /dev/null +++ b/examples/paraformer.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::paraformer::ParaformerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-paraformer-zh-2025-10-07"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Paraformer engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ParaformerModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/examples/zipformer_ctc.rs b/examples/zipformer_ctc.rs new file mode 100644 index 0000000..bcf2eec --- /dev/null +++ b/examples/zipformer_ctc.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::zipformer_ctc::ZipformerCtcModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Zipformer CTC engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ZipformerCtcModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/examples/zipformer_transducer.rs b/examples/zipformer_transducer.rs new file mode 100644 index 0000000..8067ba6 --- /dev/null +++ b/examples/zipformer_transducer.rs @@ -0,0 +1,87 @@ +use std::path::PathBuf; +use std::time::Instant; + +use transcribe_rs::onnx::zipformer_transducer::ZipformerTransducerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +fn get_audio_duration(path: &PathBuf) -> Result> { + let reader = hound::WavReader::open(path)?; + let spec = reader.spec(); + let duration = reader.duration() as f64 / spec.sample_rate as f64; + Ok(duration) +} + +fn main() -> Result<(), Box> { + env_logger::init(); + + let args: Vec = std::env::args().collect(); + let positional: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); + + let int8 = args.iter().any(|a| a == "--int8"); + let model_path = PathBuf::from( + positional + .first() + .map(|s| s.as_str()) + .unwrap_or("models/sherpa-onnx-zipformer-zh-en-2023-11-22"), + ); + let wav_path = PathBuf::from( + positional + .get(1) + .map(|s| s.as_str()) + .unwrap_or("samples/zh.wav"), + ); + + let audio_duration = get_audio_duration(&wav_path)?; + println!("Audio duration: {:.2}s", audio_duration); + + let quantization = if int8 { + Quantization::Int8 + } else { + Quantization::FP32 + }; + + println!("Using Zipformer Transducer engine"); + println!( + "Loading model: {:?} (quantization: {})", + model_path, + if int8 { "int8" } else { "fp32" } + ); + + let load_start = Instant::now(); + let mut model = ZipformerTransducerModel::load(&model_path, &quantization)?; + let load_duration = load_start.elapsed(); + println!("Model loaded in {:.2?}", load_duration); + + println!("Transcribing file: {:?}", wav_path); + let transcribe_start = Instant::now(); + + let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?; + let transcribe_duration = transcribe_start.elapsed(); + println!("Transcription completed in {:.2?}", transcribe_duration); + + let speedup_factor = audio_duration / transcribe_duration.as_secs_f64(); + println!( + "Real-time speedup: {:.2}x faster than real-time", + speedup_factor + ); + + println!("Transcription result:"); + println!("{}", result.text); + + if let Some(segments) = result.segments { + println!("\nSegments:"); + for segment in segments { + println!( + "[{:.2}s - {:.2}s]: {}", + segment.start, segment.end, segment.text + ); + } + } + + Ok(()) +} diff --git a/src/decode/bbpe.rs b/src/decode/bbpe.rs new file mode 100644 index 0000000..29d667f --- /dev/null +++ b/src/decode/bbpe.rs @@ -0,0 +1,177 @@ +//! BBPE (Byte-level BPE) symbol table for Icefall/sherpa-onnx Zipformer models. +//! +//! Handles byte-to-unicode mapping and auto-detects encoding mode based on +//! whether a `bbpe.model` file is present alongside the tokens file. + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +/// Whether tokens use BBPE byte encoding or standard BPE (literal UTF-8). +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TokenEncoding { + Bbpe, + Bpe, +} + +/// Symbol table for Icefall/sherpa-onnx Zipformer models with BBPE support. +pub struct BbpeSymbolTable { + id_to_sym: HashMap, + encoding: TokenEncoding, +} + +impl BbpeSymbolTable { + /// Load a symbol table, auto-detecting encoding mode. + /// + /// If a `bbpe.model` file exists in the same directory as `path`, + /// BBPE encoding is used; otherwise standard BPE is assumed. + pub fn load(path: &Path) -> Result { + Self::load_autodetect(path) + } + + /// Load with explicit auto-detection of encoding based on sibling files. + pub fn load_autodetect(path: &Path) -> Result { + let encoding = if let Some(dir) = path.parent() { + if dir.join("bbpe.model").exists() { + log::debug!("Detected BBPE encoding (bbpe.model found)"); + TokenEncoding::Bbpe + } else { + log::debug!("Detected standard BPE encoding (no bbpe.model)"); + TokenEncoding::Bpe + } + } else { + TokenEncoding::Bbpe + }; + Self::load_with_encoding(path, encoding) + } + + /// Load with an explicitly specified encoding. + pub fn load_with_encoding( + path: &Path, + encoding: TokenEncoding, + ) -> Result { + let contents = fs::read_to_string(path)?; + let mut id_to_sym = HashMap::new(); + for line in contents.lines() { + let line = line.trim_end(); + if line.is_empty() { + continue; + } + let parts: Vec<&str> = line.rsplitn(2, |c: char| c.is_whitespace()).collect(); + if parts.len() != 2 { + continue; + } + if let Ok(id) = parts[0].parse::() { + id_to_sym.insert(id, parts[1].to_string()); + } + } + Ok(Self { + id_to_sym, + encoding, + }) + } + + /// Look up a symbol by token ID. + pub fn get(&self, id: i32) -> Option<&str> { + self.id_to_sym.get(&id).map(|s| s.as_str()) + } + + /// Decode a sequence of token IDs to a UTF-8 string. + pub fn decode(&self, token_ids: &[i32]) -> String { + match self.encoding { + TokenEncoding::Bbpe => self.decode_bbpe(token_ids), + TokenEncoding::Bpe => self.decode_bpe(token_ids), + } + } + + fn decode_bbpe(&self, token_ids: &[i32]) -> String { + let mut raw_bytes = Vec::new(); + for &id in token_ids { + let Some(sym) = self.get(id) else { continue }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + for c in sym.chars() { + if c == '\u{2581}' { + raw_bytes.push(b' '); + } else if let Some(byte_val) = bbpe_char_to_byte(c) { + raw_bytes.push(byte_val); + } + } + } + let text = String::from_utf8_lossy(&raw_bytes); + normalize_text(text.trim()) + } + + fn decode_bpe(&self, token_ids: &[i32]) -> String { + let mut text = String::new(); + for &id in token_ids { + let Some(sym) = self.get(id) else { continue }; + if sym.starts_with('<') && sym.ends_with('>') { + continue; + } + text.push_str(&sym.replace('\u{2581}', " ")); + } + normalize_text(text.trim()) + } +} + +fn is_cjk(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' | + '\u{3400}'..='\u{4DBF}' | + '\u{F900}'..='\u{FAFF}' | + '\u{2E80}'..='\u{2EFF}' | + '\u{3000}'..='\u{303F}' | + '\u{FF00}'..='\u{FFEF}' + ) +} + +fn normalize_text(text: &str) -> String { + let text = text.to_lowercase(); + let chars: Vec = text.chars().collect(); + let mut result = String::with_capacity(text.len()); + for i in 0..chars.len() { + let c = chars[i]; + if c == ' ' { + let prev_cjk = i > 0 && is_cjk(chars[i - 1]); + let next_cjk = i + 1 < chars.len() && is_cjk(chars[i + 1]); + if prev_cjk && next_cjk { + continue; + } + } + result.push(c); + } + result +} + +/// BBPE codepoint table: maps byte value (index) to Unicode codepoint. +const BBPE_CODEPOINTS: [u32; 256] = [ + 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, + 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 308, + 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 321, 322, 323, 324, 325, 326, 327, 328, 330, + 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, + 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, + 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, + 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, + 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, +]; + +/// Convert a BBPE-encoded Unicode character back to its original byte value. +pub fn bbpe_char_to_byte(c: char) -> Option { + let cp = c as u32; + if (32..=126).contains(&cp) { + return Some(cp as u8); + } + for (byte_val, &mapped_cp) in BBPE_CODEPOINTS.iter().enumerate() { + if mapped_cp == cp { + return Some(byte_val as u8); + } + } + None +} diff --git a/src/decode/mod.rs b/src/decode/mod.rs index f33bf39..24d5a98 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -1,7 +1,9 @@ +pub mod bbpe; mod ctc; mod sentencepiece; pub mod tokens; +pub use bbpe::BbpeSymbolTable; pub use ctc::{ctc_greedy_decode, CtcDecoderResult}; pub use sentencepiece::sentencepiece_to_text; pub use tokens::{load_vocab, SymbolTable}; diff --git a/src/error.rs b/src/error.rs index d2dc466..a06a24c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -36,14 +36,14 @@ impl From for TranscribeError { } } -#[cfg(any(feature = "onnx", feature = "vad-silero"))] +#[cfg(any(feature = "onnx", feature = "vad-silero", feature = "punct"))] impl From for TranscribeError { fn from(e: ort::Error) -> Self { TranscribeError::Inference(e.to_string()) } } -#[cfg(any(feature = "audio-features", feature = "vad-silero"))] +#[cfg(any(feature = "audio-features", feature = "vad-silero", feature = "punct"))] impl From for TranscribeError { fn from(e: ndarray::ShapeError) -> Self { TranscribeError::Inference(e.to_string()) diff --git a/src/features/kaldi_fbank.rs b/src/features/kaldi_fbank.rs new file mode 100644 index 0000000..0274239 --- /dev/null +++ b/src/features/kaldi_fbank.rs @@ -0,0 +1,192 @@ +//! Kaldi-compatible FBank feature extraction. +//! +//! Implements the same feature pipeline used by sherpa-onnx / kaldi-native-fbank: +//! Povey window (Hamming^0.85), DC offset removal, preemphasis, power spectrum, +//! triangular mel filterbank, and natural log. The `high_freq` field follows the +//! Kaldi sign convention: negative values are treated as `nyquist + high_freq` +//! (e.g. `-400` → 7600 Hz at 16 kHz). + +use ndarray::Array2; +use rustfft::{num_complex::Complex, FftPlanner}; + +/// Kaldi-compatible FBank configuration matching sherpa-onnx / kaldi-native-fbank. +#[derive(Debug, Clone)] +pub struct KaldiFbankConfig { + pub num_bins: usize, + pub fft_size: usize, + pub window_size: usize, + pub hop_size: usize, + pub sample_rate: u32, + pub low_freq: f32, + /// Negative means nyquist + high_freq (Kaldi convention). -400 → 7600 Hz at 16kHz. + pub high_freq: f32, + pub preemph_coeff: f32, + pub snip_edges: bool, + pub remove_dc_offset: bool, +} + +impl Default for KaldiFbankConfig { + fn default() -> Self { + Self { + num_bins: 80, + fft_size: 512, + window_size: 400, + hop_size: 160, + sample_rate: 16000, + low_freq: 20.0, + high_freq: -400.0, + preemph_coeff: 0.97, + snip_edges: false, + remove_dc_offset: true, + } + } +} + +/// Compute Kaldi-compatible FBank features from audio samples. +/// +/// Returns an array of shape `[num_frames, num_bins]` (time-major). +pub fn compute_kaldi_fbank(samples: &[f32], config: &KaldiFbankConfig) -> Array2 { + let window_size = config.window_size; + let hop_size = config.hop_size; + let fft_size = config.fft_size; + let half_fft = fft_size / 2 + 1; + + if samples.is_empty() { + return Array2::zeros((0, config.num_bins)); + } + + let num_frames = if config.snip_edges { + if samples.len() < window_size { + return Array2::zeros((0, config.num_bins)); + } + (samples.len() - window_size) / hop_size + 1 + } else { + (samples.len() + hop_size / 2) / hop_size + }; + + if num_frames == 0 { + return Array2::zeros((0, config.num_bins)); + } + + let filterbank = mel_filterbank(config); + + // Povey window: hamming^0.85 + let window: Vec = (0..window_size) + .map(|i| { + let hamming = 0.54 + - 0.46 * (2.0 * std::f32::consts::PI * i as f32 / (window_size as f32 - 1.0)).cos(); + hamming.powf(0.85) + }) + .collect(); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(fft_size); + + let mut features = Vec::with_capacity(num_frames * config.num_bins); + + for frame_idx in 0..num_frames { + let center = if config.snip_edges { + frame_idx * hop_size + window_size / 2 + } else { + frame_idx * hop_size + }; + let start = center as isize - (window_size as isize / 2); + + let mut frame = vec![0.0f32; window_size]; + for (i, sample) in frame.iter_mut().enumerate() { + let idx = start + i as isize; + if idx >= 0 && (idx as usize) < samples.len() { + *sample = samples[idx as usize]; + } + } + + if config.remove_dc_offset { + let mean: f32 = frame.iter().sum::() / window_size as f32; + for s in frame.iter_mut() { + *s -= mean; + } + } + + if config.preemph_coeff > 0.0 { + for i in (1..window_size).rev() { + frame[i] -= config.preemph_coeff * frame[i - 1]; + } + frame[0] *= 1.0 - config.preemph_coeff; + } + + let mut buffer: Vec> = frame + .iter() + .zip(window.iter()) + .map(|(&s, &w)| Complex::new(s * w, 0.0)) + .collect(); + buffer.resize(fft_size, Complex::new(0.0, 0.0)); + fft.process(&mut buffer); + + let power: Vec = buffer[..half_fft].iter().map(|c| c.norm_sqr()).collect(); + + for filter in &filterbank { + let mut sum = 0.0f32; + for (i, &w) in filter.iter().enumerate() { + sum += w * power[i]; + } + features.push(if sum > f32::EPSILON { + sum.ln() + } else { + f32::EPSILON.ln() + }); + } + } + + Array2::from_shape_vec((num_frames, config.num_bins), features).unwrap() +} + +/// Build a triangular mel filterbank matrix. +/// +/// Returns a `Vec` of `num_bins` filters, each of length `fft_size / 2 + 1`. +fn mel_filterbank(config: &KaldiFbankConfig) -> Vec> { + let num_bins = config.num_bins; + let fft_size = config.fft_size; + let sample_rate = config.sample_rate as f32; + let nyquist = sample_rate / 2.0; + let low_freq = config.low_freq; + let high_freq = if config.high_freq <= 0.0 { + nyquist + config.high_freq + } else { + config.high_freq + }; + + let hz_to_mel = |hz: f32| 1127.0 * (1.0 + hz / 700.0).ln(); + let mel_to_hz = |mel: f32| 700.0 * ((mel / 1127.0).exp() - 1.0); + + let low_mel = hz_to_mel(low_freq); + let high_mel = hz_to_mel(high_freq); + + let num_points = num_bins + 2; + let mel_points: Vec = (0..num_points) + .map(|i| low_mel + (high_mel - low_mel) * i as f32 / (num_points - 1) as f32) + .collect(); + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + let fft_bins: Vec = hz_points + .iter() + .map(|&hz| ((hz * fft_size as f32) / sample_rate).floor() as usize) + .collect(); + + let half_fft = fft_size / 2 + 1; + let mut filterbank = vec![vec![0.0f32; half_fft]; num_bins]; + for (i, filter) in filterbank.iter_mut().enumerate() { + let left = fft_bins[i]; + let center = fft_bins[i + 1]; + let right = fft_bins[i + 2]; + if center > left { + for (idx, val) in filter[left..center.min(half_fft)].iter_mut().enumerate() { + *val = idx as f32 / (center - left) as f32; + } + } + if right > center { + for (idx, val) in filter[center..right.min(half_fft)].iter_mut().enumerate() { + *val = (right - center - idx) as f32 / (right - center) as f32; + } + } + } + filterbank +} diff --git a/src/features/mod.rs b/src/features/mod.rs index bc9d16b..c7a894c 100644 --- a/src/features/mod.rs +++ b/src/features/mod.rs @@ -1,7 +1,9 @@ mod cmvn; +pub mod kaldi_fbank; mod lfr; mod mel; pub use cmvn::apply_cmvn; +pub use kaldi_fbank::{compute_kaldi_fbank, KaldiFbankConfig}; pub use lfr::apply_lfr; pub use mel::{compute_mel, MelConfig, WindowType}; diff --git a/src/lib.rs b/src/lib.rs index deb229e..090d56f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,6 +116,9 @@ pub mod remote; #[cfg(feature = "openai")] pub use remote::RemoteTranscriptionEngine; +#[cfg(feature = "punct")] +pub mod punct; + use std::path::Path; /// Describes the capabilities of a speech model. diff --git a/src/onnx/mod.rs b/src/onnx/mod.rs index 98ee2fc..9eb7b82 100644 --- a/src/onnx/mod.rs +++ b/src/onnx/mod.rs @@ -22,5 +22,8 @@ pub enum Quantization { pub mod canary; pub mod gigaam; pub mod moonshine; +pub mod paraformer; pub mod parakeet; pub mod sense_voice; +pub mod zipformer_ctc; +pub mod zipformer_transducer; diff --git a/src/onnx/paraformer/mod.rs b/src/onnx/paraformer/mod.rs new file mode 100644 index 0000000..688682e --- /dev/null +++ b/src/onnx/paraformer/mod.rs @@ -0,0 +1,642 @@ +//! Paraformer ONNX speech recognition engine. +//! +//! Non-autoregressive end-to-end ASR model from FunASR/ModelScope. +//! Uses its own fbank feature extraction (Hamming window, dB scale), +//! LFR frame stacking, and a custom symbol table with `@@` subword joining. + +use ndarray::{Array1, Array2}; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use rustfft::{num_complex::Complex, FftPlanner}; +use std::collections::HashMap; +use std::f32::consts::PI; +use std::path::Path; + +use super::session; +use super::Quantization; +use crate::features::apply_lfr; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Paraformer", + engine_id: "paraformer", + sample_rate: 16000, + languages: &["zh", "en", "yue"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Paraformer. +#[derive(Debug, Clone, Default)] +pub struct ParaformerParams { + /// Language hint (currently unused, Paraformer handles zh/en/yue automatically). + pub language: Option, +} + +// ---- Metadata ---- + +struct ParaformerMetadata { + lfr_window_size: usize, + lfr_window_shift: usize, + blank_id: i32, + sos_id: i32, + eos_id: i32, +} + +// ---- Symbol Table ---- + +/// Paraformer-specific symbol table with `@@` subword joining and CJK-aware spacing. +struct ParaformerSymbolTable { + id_to_sym: HashMap, +} + +impl ParaformerSymbolTable { + /// Load a tokens.txt file where each line is `symbol id`. + fn load(path: &Path) -> Result { + let content = std::fs::read_to_string(path)?; + let mut id_to_sym = HashMap::new(); + + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Split on last whitespace: "symbol id" + if let Some(pos) = line.rfind(|c: char| c.is_ascii_whitespace()) { + let sym = &line[..pos]; + let id_str = line[pos..].trim(); + if let Ok(id) = id_str.parse::() { + id_to_sym.insert(id, sym.to_string()); + } + } + } + + log::info!( + "Loaded Paraformer symbol table with {} tokens from {:?}", + id_to_sym.len(), + path + ); + Ok(Self { id_to_sym }) + } + + fn get(&self, id: i32) -> Option<&str> { + self.id_to_sym.get(&id).map(|s| s.as_str()) + } + + /// Decode a sequence of token IDs into text. + fn decode(&self, token_ids: &[i32]) -> String { + let mut text = String::new(); + let mut prev_join_to_next = false; + + for &id in token_ids { + let Some(sym) = self.get(id) else { + continue; + }; + if is_special_symbol(sym) { + continue; + } + + let joins_next = sym.ends_with("@@"); + let clean = sym.trim_end_matches("@@"); + if clean.is_empty() { + prev_join_to_next = joins_next; + continue; + } + + if clean.starts_with('\u{2581}') { + let piece = clean.trim_start_matches('\u{2581}'); + if !piece.is_empty() { + if !text.is_empty() && !text.ends_with(' ') { + text.push(' '); + } + text.push_str(piece); + } + prev_join_to_next = joins_next; + continue; + } + + if !text.is_empty() && !prev_join_to_next { + let prev_char = text.chars().last(); + let curr_is_ascii_word = is_ascii_word_piece(clean); + let prev_is_ascii_word = prev_char.map(is_ascii_word_char).unwrap_or(false); + let prev_is_cjk = prev_char.map(is_cjk).unwrap_or(false); + if curr_is_ascii_word && (prev_is_ascii_word || prev_is_cjk) && !text.ends_with(' ') + { + text.push(' '); + } + } + + text.push_str(clean); + prev_join_to_next = joins_next; + } + + text.trim().to_string() + } +} + +fn is_special_symbol(sym: &str) -> bool { + sym == "" + || sym == "" + || sym == "" + || sym == "" + || sym == "" + || (sym.starts_with('<') && sym.ends_with('>')) +} + +fn is_ascii_word_piece(s: &str) -> bool { + !s.is_empty() && s.chars().all(is_ascii_word_char) +} + +fn is_ascii_word_char(c: char) -> bool { + c.is_ascii_alphanumeric() +} + +fn is_cjk(c: char) -> bool { + let code = c as u32; + (0x4E00..=0x9FFF).contains(&code) + || (0x3400..=0x4DBF).contains(&code) + || (0x20000..=0x2A6DF).contains(&code) + || (0x2A700..=0x2B73F).contains(&code) + || (0x2B740..=0x2B81F).contains(&code) + || (0x2B820..=0x2CEAF).contains(&code) +} + +// ---- Feature Extraction ---- + +/// Compute Paraformer-style fbank features (Hamming window, dB scale). +/// +/// This is NOT the same as Kaldi fbank or the upstream `compute_mel()`. +/// Key differences: +/// - Standard Hamming window (no Povey modification) +/// - No preemphasis, no DC offset removal +/// - dB scale output: `10.0 * log10(energy)` with -80 dB floor +/// - snip_edges=true +/// +/// Returns [num_frames, num_bins] (80 bins by default). +fn compute_paraformer_fbank(samples: &[f32]) -> Array2 { + const NUM_BINS: usize = 80; + const FFT_SIZE: usize = 512; + const WINDOW_SIZE: usize = 400; + const HOP_SIZE: usize = 160; + const SAMPLE_RATE: f32 = 16000.0; + const LOW_FREQ: f32 = 0.0; + const HIGH_FREQ: f32 = 8000.0; + + if samples.len() < WINDOW_SIZE { + return Array2::zeros((0, NUM_BINS)); + } + + let num_frames = (samples.len() - WINDOW_SIZE) / HOP_SIZE + 1; + let num_fft_bins = FFT_SIZE / 2 + 1; + + // Standard Hamming window + let window: Vec = (0..WINDOW_SIZE) + .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f32 / (WINDOW_SIZE as f32 - 1.0)).cos()) + .collect(); + + // Mel filterbank [NUM_BINS, num_fft_bins] + let mel_banks = mel_filterbank(NUM_BINS, FFT_SIZE, SAMPLE_RATE, LOW_FREQ, HIGH_FREQ); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(FFT_SIZE); + + let mut features = Array2::zeros((num_frames, NUM_BINS)); + + for i in 0..num_frames { + let start = i * HOP_SIZE; + + // Extract and window the frame + let mut fft_input: Vec> = Vec::with_capacity(FFT_SIZE); + for j in 0..WINDOW_SIZE { + let sample = if start + j < samples.len() { + samples[start + j] + } else { + 0.0 + }; + fft_input.push(Complex::new(sample * window[j], 0.0)); + } + // Zero-pad to FFT_SIZE + fft_input.resize(FFT_SIZE, Complex::new(0.0, 0.0)); + + fft.process(&mut fft_input); + + // Power spectrum + let power_spectrum: Vec = fft_input[..num_fft_bins] + .iter() + .map(|c| c.norm_sqr()) + .collect(); + + // Apply mel filterbank and convert to dB scale + for m in 0..NUM_BINS { + let energy: f32 = mel_banks + .row(m) + .iter() + .zip(power_spectrum.iter()) + .map(|(&w, &p)| w * p) + .sum(); + + // dB scale: 10 * log10(energy), with -80 dB floor + let db = if energy < 1.0e-10 { + -80.0 + } else { + 10.0 * energy.log10() + }; + features[[i, m]] = db.max(-80.0); + } + } + + features +} + +/// Compute mel filterbank matrix of shape [num_mels, num_fft_bins]. +fn mel_filterbank( + num_mels: usize, + fft_size: usize, + sample_rate: f32, + low_freq: f32, + high_freq: f32, +) -> Array2 { + let num_fft_bins = fft_size / 2 + 1; + + let mel_low = hz_to_mel(low_freq); + let mel_high = hz_to_mel(high_freq); + + let num_points = num_mels + 2; + let mel_points: Vec = (0..num_points) + .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (num_points - 1) as f32) + .collect(); + + let hz_points: Vec = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); + + let bin_points: Vec = hz_points + .iter() + .map(|&f| f * fft_size as f32 / sample_rate) + .collect(); + + let mut banks = Array2::zeros((num_mels, num_fft_bins)); + + for m in 0..num_mels { + let left = bin_points[m]; + let center = bin_points[m + 1]; + let right = bin_points[m + 2]; + + for k in 0..num_fft_bins { + let kf = k as f32; + if kf > left && kf < center { + banks[[m, k]] = (kf - left) / (center - left); + } else if kf >= center && kf < right { + banks[[m, k]] = (right - kf) / (right - center); + } + } + } + + banks +} + +fn hz_to_mel(hz: f32) -> f32 { + 1127.0 * (1.0 + hz / 700.0).ln() +} + +fn mel_to_hz(mel: f32) -> f32 { + 700.0 * ((mel / 1127.0).exp() - 1.0) +} + +// ---- CMVN ---- + +/// Apply mean-only CMVN normalization (subtract mean, no stddev scaling). +fn apply_mean_cmvn(features: &mut Array2, mean: &Array1) { + let ncols = features.ncols(); + for mut row in features.rows_mut() { + for j in 0..ncols { + row[j] -= mean[j]; + } + } +} + +/// Load CMVN mean from an `am.mvn` file (Kaldi-style format). +/// +/// Parses the `` section and extracts the mean vector. +/// Returns `None` if the file doesn't contain the expected format. +fn load_cmvn_mean(path: &Path, target_dim: usize) -> Result>, std::io::Error> { + let content = std::fs::read_to_string(path)?; + let Some(start_idx) = content.find("") else { + return Ok(None); + }; + let rest = &content[start_idx..]; + let Some(lb_rel) = rest.find('[') else { + return Ok(None); + }; + let Some(rb_rel) = rest.find(']') else { + return Ok(None); + }; + if rb_rel <= lb_rel { + return Ok(None); + } + let body = &rest[lb_rel + 1..rb_rel]; + let mut values = Vec::new(); + for tok in body.split_whitespace() { + if let Ok(v) = tok.parse::() { + values.push(v); + } + } + if values.len() < target_dim { + return Ok(None); + } + Ok(Some(Array1::from_vec( + values.into_iter().take(target_dim).collect(), + ))) +} + +// ---- Model ---- + +pub struct ParaformerModel { + session: Session, + symbol_table: ParaformerSymbolTable, + metadata: ParaformerMetadata, + cmvn_mean: Option>, + speech_input_name: String, + speech_lengths_input_name: String, + #[allow(dead_code)] + logits_output_name: String, + #[allow(dead_code)] + token_num_output_name: Option, +} + +impl ParaformerModel { + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let model_path = session::resolve_model_path(model_dir, "model", quantization); + let tokens_path = model_dir.join("tokens.txt"); + let cmvn_path = model_dir.join("am.mvn"); + + if !model_path.exists() { + return Err(TranscribeError::ModelNotFound(model_path)); + } + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading Paraformer model from {:?}...", model_path); + let session = session::create_session(&model_path)?; + + // Read metadata from ONNX model + let lfr_window_size = + session::read_metadata_i32(&session, "lfr_window_size", Some(7))?.unwrap() as usize; + let lfr_window_shift = + session::read_metadata_i32(&session, "lfr_window_shift", Some(6))?.unwrap() as usize; + let blank_id = session::read_metadata_i32(&session, "blank_id", Some(0))?.unwrap(); + let sos_id = session::read_metadata_i32(&session, "sos_id", Some(1))?.unwrap(); + let eos_id = session::read_metadata_i32(&session, "eos_id", Some(2))?.unwrap(); + + let metadata = ParaformerMetadata { + lfr_window_size, + lfr_window_shift, + blank_id, + sos_id, + eos_id, + }; + + log::info!( + "Paraformer metadata: lfr_window={}x{}, blank={}, sos={}, eos={}", + lfr_window_size, + lfr_window_shift, + blank_id, + sos_id, + eos_id, + ); + + // Detect I/O names from session + let inputs: Vec = session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + let outputs: Vec = session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let speech_input_name = inputs + .iter() + .find(|n| n.contains("speech")) + .cloned() + .unwrap_or_else(|| { + inputs + .first() + .cloned() + .unwrap_or_else(|| "speech".to_string()) + }); + + let speech_lengths_input_name = inputs + .iter() + .find(|n| n.contains("speech_lengths") || n.contains("length")) + .cloned() + .unwrap_or_else(|| { + inputs + .get(1) + .cloned() + .unwrap_or_else(|| "speech_lengths".to_string()) + }); + + let logits_output_name = outputs + .iter() + .find(|n| n.contains("logits")) + .cloned() + .unwrap_or_else(|| { + outputs + .first() + .cloned() + .unwrap_or_else(|| "logits".to_string()) + }); + + let token_num_output_name = outputs.iter().find(|n| n.contains("token_num")).cloned(); + + log::debug!( + "I/O names: speech={}, lengths={}, logits={}, token_num={:?}", + speech_input_name, + speech_lengths_input_name, + logits_output_name, + token_num_output_name, + ); + + // Load symbol table + let symbol_table = ParaformerSymbolTable::load(&tokens_path)?; + + // Load CMVN mean (LFR dim = 80 * lfr_window_size) + let lfr_dim = 80 * lfr_window_size; + let cmvn_mean = if cmvn_path.exists() { + match load_cmvn_mean(&cmvn_path, lfr_dim) { + Ok(mean) => { + if mean.is_some() { + log::info!("Loaded CMVN mean from {:?} (dim={})", cmvn_path, lfr_dim); + } + mean + } + Err(e) => { + log::warn!("Failed to load CMVN from {:?}: {}", cmvn_path, e); + None + } + } + } else { + log::debug!("No am.mvn file found at {:?}, skipping CMVN", cmvn_path); + None + }; + + Ok(Self { + session, + symbol_table, + metadata, + cmvn_mean, + speech_input_name, + speech_lengths_input_name, + logits_output_name, + token_num_output_name, + }) + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ParaformerParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Paraformer fbank features [frames, 80] + let features = compute_paraformer_fbank(samples); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!( + "Paraformer fbank: [{}, {}]", + features.nrows(), + features.ncols() + ); + + // 2. Apply LFR + let features = apply_lfr( + &features, + self.metadata.lfr_window_size, + self.metadata.lfr_window_shift, + ); + + log::debug!("After LFR: [{}, {}]", features.nrows(), features.ncols()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + // 3. Apply mean-only CMVN + let mut features = features; + if let Some(ref mean) = self.cmvn_mean { + apply_mean_cmvn(&mut features, mean); + } + + // 4. Forward pass + let logits = self.forward(&features)?; + + log::debug!("Logits shape: {:?}", logits.shape()); + + // 5. Decode (non-autoregressive argmax, NOT CTC) + let token_ids = self.decode_logits(&logits); + + // 6. Convert to text + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run ONNX forward pass. Returns logits [1, T, vocab_size]. + fn forward(&mut self, features: &Array2) -> Result, TranscribeError> { + let num_frames = features.nrows(); + + // Shape: [1, T, D] + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; + let speech_lengths = ndarray::arr1(&[num_frames as i32]); + + let feat_dyn = feat_3d.into_dyn(); + let lengths_dyn = speech_lengths.into_dyn(); + + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lengths = TensorRef::from_array_view(lengths_dyn.view())?; + + let ort_inputs = inputs![ + self.speech_input_name.as_str() => t_feat, + self.speech_lengths_input_name.as_str() => t_lengths, + ]; + + let outputs = self.session.run(ort_inputs)?; + + let logits = outputs[0].try_extract_array::()?; + let logits_owned = logits.to_owned().into_dimensionality::()?; + + Ok(logits_owned) + } + + /// Decode logits using argmax with blank/sos/eos filtering. + /// + /// Paraformer is non-autoregressive, so this is a simple argmax over the + /// vocab dimension, NOT CTC greedy decode. + fn decode_logits(&self, logits: &ndarray::Array3) -> Vec { + let blank_id = self.metadata.blank_id; + let sos_id = self.metadata.sos_id; + let eos_id = self.metadata.eos_id; + + let seq_len = logits.shape()[1]; + let mut token_ids = Vec::new(); + + for t in 0..seq_len { + // Argmax over vocab dimension + let mut best_id = 0i32; + let mut best_val = f32::NEG_INFINITY; + for (v, &val) in logits.slice(ndarray::s![0, t, ..]).iter().enumerate() { + if val > best_val { + best_val = val; + best_id = v as i32; + } + } + + // Skip blank, sos, eos + if best_id == blank_id || best_id == sos_id || best_id == eos_id { + continue; + } + + token_ids.push(best_id); + } + + token_ids + } +} + +impl SpeechModel for ParaformerModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} diff --git a/src/onnx/session.rs b/src/onnx/session.rs index 676c6da..9a8ff4c 100644 --- a/src/onnx/session.rs +++ b/src/onnx/session.rs @@ -129,14 +129,14 @@ fn build_session( .commit_from_file(path)?; for input in session.inputs() { - log::info!( + log::debug!( "Model input: name={}, type={:?}", input.name(), input.dtype() ); } for output in session.outputs() { - log::info!( + log::debug!( "Model output: name={}, type={:?}", output.name(), output.dtype() diff --git a/src/onnx/zipformer_ctc/mod.rs b/src/onnx/zipformer_ctc/mod.rs new file mode 100644 index 0000000..57b812f --- /dev/null +++ b/src/onnx/zipformer_ctc/mod.rs @@ -0,0 +1,324 @@ +//! Zipformer CTC ONNX speech recognition engine. +//! +//! Supports sherpa-onnx Zipformer CTC models (e.g. from Icefall) for +//! Chinese and English transcription. Streaming models (with `cached_*` +//! inputs) are rejected at load time. + +use ndarray::Array2; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use std::path::{Path, PathBuf}; + +use super::session; +use super::Quantization; +use crate::decode::{ctc_greedy_decode, BbpeSymbolTable}; +use crate::features::{compute_kaldi_fbank, KaldiFbankConfig}; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Zipformer CTC", + engine_id: "zipformer_ctc", + sample_rate: 16000, + languages: &["zh", "en"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Zipformer CTC. +#[derive(Debug, Clone, Default)] +pub struct ZipformerCtcParams { + /// Language hint (currently unused; the model handles zh/en automatically). + pub language: Option, +} + +// ---- Model ---- + +pub struct ZipformerCtcModel { + session: Session, + symbol_table: BbpeSymbolTable, + blank_id: i64, + x_input_name: String, + x_lens_input_name: String, + #[allow(dead_code)] + log_probs_output_name: String, + /// Index of the output that contains output lengths, if present. + log_probs_len_output_idx: Option, +} + +impl ZipformerCtcModel { + /// Load a Zipformer CTC model from `model_dir`. + /// + /// Attempts `session::resolve_model_path(dir, "model", quantization)` first, + /// then falls back to scanning the directory for any `.onnx` file (for + /// sherpa-onnx models with names like `model-epoch-34-avg-19.int8.onnx`). + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let model_path = Self::find_model_file(model_dir, quantization)?; + let tokens_path = model_dir.join("tokens.txt"); + + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading Zipformer CTC model from {:?}...", model_path); + let session = session::create_session(&model_path)?; + + // Reject streaming models — they have cached_* inputs + let input_names: Vec = session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + if input_names.iter().any(|n| n.starts_with("cached_")) { + return Err(TranscribeError::Config(format!( + "Streaming Zipformer models are not supported (found cached_* inputs in {:?}). \ + Use a non-streaming (offline) model.", + model_path + ))); + } + + log::debug!("Model inputs: {:?}", input_names); + + let output_names: Vec = session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + log::debug!("Model outputs: {:?}", output_names); + + // Detect input names + let x_input_name = input_names + .iter() + .find(|n| n.as_str() == "x" || n.contains("feat") || n.contains("input")) + .cloned() + .unwrap_or_else(|| { + input_names + .first() + .cloned() + .unwrap_or_else(|| "x".to_string()) + }); + + let x_lens_input_name = input_names + .iter() + .find(|n| n.contains("len") || n.contains("length")) + .cloned() + .unwrap_or_else(|| { + input_names + .get(1) + .cloned() + .unwrap_or_else(|| "x_lens".to_string()) + }); + + // Detect output names + let log_probs_output_name = output_names + .iter() + .find(|n| n.contains("log_prob") || n.contains("logit") || n.contains("prob")) + .cloned() + .unwrap_or_else(|| { + output_names + .first() + .cloned() + .unwrap_or_else(|| "log_probs".to_string()) + }); + + let log_probs_len_output_idx = output_names + .iter() + .position(|n| n.contains("len") || n.contains("length")); + + log::debug!( + "I/O mapping: x={}, x_lens={}, log_probs={}, log_probs_len_idx={:?}", + x_input_name, + x_lens_input_name, + log_probs_output_name, + log_probs_len_output_idx, + ); + + // Load BBPE symbol table (auto-detects BBPE vs BPE encoding) + let symbol_table = BbpeSymbolTable::load(&tokens_path)?; + + // blank_id is always 0 for sherpa-onnx Zipformer CTC models + let blank_id = 0i64; + + Ok(Self { + session, + symbol_table, + blank_id, + x_input_name, + x_lens_input_name, + log_probs_output_name, + log_probs_len_output_idx, + }) + } + + /// Find the ONNX model file in `model_dir`. + /// + /// Priority: + /// 1. `session::resolve_model_path(dir, "model", quantization)` — standard naming + /// 2. Scan directory for `*.int8.onnx` (when Int8 requested) or `*.onnx` + fn find_model_file( + model_dir: &Path, + quantization: &Quantization, + ) -> Result { + // Try standard path first + let standard_path = session::resolve_model_path(model_dir, "model", quantization); + if standard_path.exists() { + return Ok(standard_path); + } + + // Fallback: scan directory for onnx files + let prefer_int8 = *quantization == Quantization::Int8; + + let read_dir = std::fs::read_dir(model_dir).map_err(|e| { + TranscribeError::Io(std::io::Error::new( + e.kind(), + format!("cannot read model directory {:?}: {}", model_dir, e), + )) + })?; + + let mut int8_candidates: Vec = Vec::new(); + let mut fp32_candidates: Vec = Vec::new(); + + for entry in read_dir.flatten() { + let path = entry.path(); + if path.extension().and_then(|e| e.to_str()) == Some("onnx") { + let name = path.file_name().and_then(|n| n.to_str()).unwrap_or(""); + if name.contains("int8") || name.contains("int4") { + int8_candidates.push(path); + } else { + fp32_candidates.push(path); + } + } + } + + // Sort for determinism + int8_candidates.sort(); + fp32_candidates.sort(); + + if prefer_int8 { + if let Some(p) = int8_candidates.into_iter().next() { + log::debug!("Found int8 model by directory scan: {:?}", p); + return Ok(p); + } + } + + if let Some(p) = fp32_candidates.into_iter().next() { + log::debug!("Found model by directory scan: {:?}", p); + return Ok(p); + } + + // Last resort: return the int8 candidate even if fp32 preferred + Err(TranscribeError::ModelNotFound(model_dir.join("model.onnx"))) + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ZipformerCtcParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Kaldi FBank features [frames, 80] + let features = compute_kaldi_fbank(samples, &KaldiFbankConfig::default()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!("Kaldi fbank: [{}, {}]", features.nrows(), features.ncols()); + + // 2. Run ONNX forward pass → log_probs [1, T', vocab] + let (log_probs, output_len) = self.forward(&features)?; + + log::debug!( + "log_probs shape: {:?}, output_len={}", + log_probs.shape(), + output_len + ); + + // 3. CTC greedy decode (expects [batch, time, vocab]) + let logits_lengths = vec![output_len]; + let results = ctc_greedy_decode(&log_probs.view(), &logits_lengths, self.blank_id); + + // 4. Convert token IDs (i64 → i32) and decode to text + let token_ids: Vec = results[0].tokens.iter().map(|&t| t as i32).collect(); + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run ONNX forward pass. + /// + /// Returns `(log_probs [1, T, vocab], output_len)` where `output_len` is + /// the valid frame count for batch element 0. + fn forward( + &mut self, + features: &Array2, + ) -> Result<(ndarray::Array3, i64), TranscribeError> { + let num_frames = features.nrows() as i64; + + // Shape: [1, T, 80] + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, features.nrows(), features.ncols()))?; + let x_lens = ndarray::arr1(&[num_frames]); + + let feat_dyn = feat_3d.into_dyn(); + let lens_dyn = x_lens.into_dyn(); + + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lens = TensorRef::from_array_view(lens_dyn.view())?; + + let ort_inputs = inputs![ + self.x_input_name.as_str() => t_feat, + self.x_lens_input_name.as_str() => t_lens, + ]; + + let outputs = self.session.run(ort_inputs)?; + + // Extract log_probs — always the first output, shape [1, T', vocab] + let log_probs = outputs[0].try_extract_array::()?; + let log_probs = log_probs.to_owned().into_dimensionality::()?; + + // Determine output length: use the length output if available, else T' + let output_len = if let Some(len_idx) = self.log_probs_len_output_idx { + if len_idx < outputs.len() { + let lens = outputs[len_idx].try_extract_array::()?; + lens.first().copied().unwrap_or(log_probs.shape()[1] as i64) + } else { + log_probs.shape()[1] as i64 + } + } else { + log_probs.shape()[1] as i64 + }; + + Ok((log_probs, output_len)) + } +} + +impl SpeechModel for ZipformerCtcModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} diff --git a/src/onnx/zipformer_transducer/mod.rs b/src/onnx/zipformer_transducer/mod.rs new file mode 100644 index 0000000..89838b7 --- /dev/null +++ b/src/onnx/zipformer_transducer/mod.rs @@ -0,0 +1,553 @@ +//! Zipformer Transducer (RNN-T) ONNX speech recognition engine. +//! +//! Supports sherpa-onnx Zipformer Transducer models with 3 ONNX sessions +//! (encoder, decoder, joiner) and RNN-T greedy search decoding. Streaming +//! models (with `cached_*` inputs) are rejected at load time. + +use ndarray::{Array2, Array3, ArrayView1}; +use ort::inputs; +use ort::session::Session; +use ort::value::TensorRef; +use std::path::{Path, PathBuf}; + +use super::session; +use super::Quantization; +use crate::decode::BbpeSymbolTable; +use crate::features::{compute_kaldi_fbank, KaldiFbankConfig}; +use crate::TranscribeError; +use crate::{ModelCapabilities, SpeechModel, TranscribeOptions, TranscriptionResult}; + +const CAPABILITIES: ModelCapabilities = ModelCapabilities { + name: "Zipformer Transducer", + engine_id: "zipformer_transducer", + sample_rate: 16000, + languages: &["zh", "en", "vi", "ru", "ko"], + supports_timestamps: false, + supports_translation: false, + supports_streaming: false, +}; + +/// Per-model inference parameters for Zipformer Transducer. +#[derive(Debug, Clone, Default)] +pub struct ZipformerTransducerParams { + /// Language hint (currently unused; the model handles languages automatically). + pub language: Option, +} + +// ---- Model ---- + +pub struct ZipformerTransducerModel { + encoder_session: Session, + decoder_session: Session, + joiner_session: Session, + symbol_table: BbpeSymbolTable, + blank_id: i32, + context_size: usize, // always 2 + // Encoder I/O names + enc_x_name: String, + enc_x_lens_name: String, + enc_out_name: String, + enc_out_lens_name: String, + // Decoder I/O names + dec_y_name: String, + dec_out_name: String, + // Joiner I/O names + join_enc_name: String, + join_dec_name: String, + join_logit_name: String, +} + +impl ZipformerTransducerModel { + /// Load a Zipformer Transducer model from `model_dir`. + /// + /// Expects three ONNX files (encoder, decoder, joiner) and a `tokens.txt` + /// file in the model directory. + pub fn load(model_dir: &Path, quantization: &Quantization) -> Result { + let suffix = match quantization { + Quantization::FP32 => "fp32", + Quantization::FP16 => "fp16", + Quantization::Int8 => "int8", + }; + + let encoder_path = Self::find_model_file(model_dir, "encoder", suffix)?; + let decoder_path = Self::find_model_file(model_dir, "decoder", suffix)?; + let joiner_path = Self::find_model_file(model_dir, "joiner", suffix)?; + + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!( + "Loading Zipformer Transducer model from {:?} (encoder={:?}, decoder={:?}, joiner={:?})...", + model_dir, + encoder_path.file_name().unwrap_or_default(), + decoder_path.file_name().unwrap_or_default(), + joiner_path.file_name().unwrap_or_default(), + ); + + let encoder_session = session::create_session(&encoder_path)?; + let decoder_session = session::create_session(&decoder_path)?; + let joiner_session = session::create_session(&joiner_path)?; + + // Reject streaming models — they have cached_* inputs + let enc_input_names: Vec = encoder_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + if enc_input_names.iter().any(|n| n.starts_with("cached_")) { + return Err(TranscribeError::Config(format!( + "Streaming Zipformer Transducer models are not supported (found cached_* inputs in {:?}). \ + Use a non-streaming (offline) model.", + encoder_path + ))); + } + + let enc_output_names: Vec = encoder_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let dec_input_names: Vec = decoder_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + let dec_output_names: Vec = decoder_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + let join_input_names: Vec = joiner_session + .inputs() + .iter() + .map(|i| i.name().to_string()) + .collect(); + + let join_output_names: Vec = joiner_session + .outputs() + .iter() + .map(|o| o.name().to_string()) + .collect(); + + log::debug!( + "Encoder inputs: {:?}, outputs: {:?}", + enc_input_names, + enc_output_names + ); + log::debug!( + "Decoder inputs: {:?}, outputs: {:?}", + dec_input_names, + dec_output_names + ); + log::debug!( + "Joiner inputs: {:?}, outputs: {:?}", + join_input_names, + join_output_names + ); + + // Detect encoder I/O names + let enc_x_name = Self::find_name(&enc_input_names, &["x", "features", "input"]) + .unwrap_or_else(|| { + enc_input_names + .first() + .cloned() + .unwrap_or_else(|| "x".to_string()) + }); + + let enc_x_lens_name = + Self::find_name(&enc_input_names, &["x_lens", "x_length", "input_lengths"]) + .unwrap_or_else(|| { + enc_input_names + .get(1) + .cloned() + .unwrap_or_else(|| "x_lens".to_string()) + }); + + let enc_out_name = Self::find_name( + &enc_output_names, + &["encoder_out", "output", "encoder_output"], + ) + .unwrap_or_else(|| { + enc_output_names + .first() + .cloned() + .unwrap_or_else(|| "encoder_out".to_string()) + }); + + let enc_out_lens_name = Self::find_name( + &enc_output_names, + &["encoder_out_lens", "encoder_out_length", "output_lengths"], + ) + .unwrap_or_else(|| { + enc_output_names + .get(1) + .cloned() + .unwrap_or_else(|| "encoder_out_lens".to_string()) + }); + + // Detect decoder I/O names + let dec_y_name = Self::find_name(&dec_input_names, &["y", "input", "decoder_input"]) + .unwrap_or_else(|| { + dec_input_names + .first() + .cloned() + .unwrap_or_else(|| "y".to_string()) + }); + + let dec_out_name = Self::find_name( + &dec_output_names, + &["decoder_out", "output", "decoder_output"], + ) + .unwrap_or_else(|| { + dec_output_names + .first() + .cloned() + .unwrap_or_else(|| "decoder_out".to_string()) + }); + + // Detect joiner I/O names + let join_enc_name = Self::find_name( + &join_input_names, + &["encoder_out", "enc_out", "encoder_input"], + ) + .unwrap_or_else(|| { + join_input_names + .first() + .cloned() + .unwrap_or_else(|| "encoder_out".to_string()) + }); + + let join_dec_name = Self::find_name( + &join_input_names, + &["decoder_out", "dec_out", "decoder_input"], + ) + .unwrap_or_else(|| { + join_input_names + .get(1) + .cloned() + .unwrap_or_else(|| "decoder_out".to_string()) + }); + + let join_logit_name = + Self::find_name(&join_output_names, &["logit", "output", "joiner_output"]) + .unwrap_or_else(|| { + join_output_names + .first() + .cloned() + .unwrap_or_else(|| "logit".to_string()) + }); + + log::debug!( + "I/O mapping: enc({}, {} -> {}, {}), dec({} -> {}), join({}, {} -> {})", + enc_x_name, + enc_x_lens_name, + enc_out_name, + enc_out_lens_name, + dec_y_name, + dec_out_name, + join_enc_name, + join_dec_name, + join_logit_name, + ); + + // Load BBPE symbol table (auto-detects BBPE vs BPE encoding) + let symbol_table = BbpeSymbolTable::load(&tokens_path)?; + + Ok(Self { + encoder_session, + decoder_session, + joiner_session, + symbol_table, + blank_id: 0, + context_size: 2, + enc_x_name, + enc_x_lens_name, + enc_out_name, + enc_out_lens_name, + dec_y_name, + dec_out_name, + join_enc_name, + join_dec_name, + join_logit_name, + }) + } + + /// Find an ONNX model file by component name, trying various naming conventions. + /// + /// Tries in order: + /// 1. `{component}.{suffix}.onnx` + /// 2. `{component}.onnx` + /// 3. Any file starting with `{component}` ending with `.{suffix}.onnx` + /// 4. Any file starting with `{component}` ending with `.onnx` + fn find_model_file( + model_dir: &Path, + component: &str, + suffix: &str, + ) -> Result { + // 1. Try exact: {component}.{suffix}.onnx + let exact_suffixed = model_dir.join(format!("{}.{}.onnx", component, suffix)); + if exact_suffixed.exists() { + return Ok(exact_suffixed); + } + + // 2. Try exact: {component}.onnx + let exact_plain = model_dir.join(format!("{}.onnx", component)); + if exact_plain.exists() { + return Ok(exact_plain); + } + + // 3/4. Scan directory for files matching the component prefix + if let Ok(read_dir) = std::fs::read_dir(model_dir) { + let mut suffixed_candidates: Vec = Vec::new(); + let mut plain_candidates: Vec = Vec::new(); + + for entry in read_dir.flatten() { + let path = entry.path(); + let file_name = match path.file_name().and_then(|n| n.to_str()) { + Some(n) => n.to_string(), + None => continue, + }; + + if !file_name.starts_with(component) { + continue; + } + + if file_name.ends_with(&format!(".{}.onnx", suffix)) { + suffixed_candidates.push(path); + } else if file_name.ends_with(".onnx") { + plain_candidates.push(path); + } + } + + // Sort for determinism + suffixed_candidates.sort(); + plain_candidates.sort(); + + // 3. Prefer suffixed match + if let Some(p) = suffixed_candidates.into_iter().next() { + log::debug!( + "Found {} model by directory scan: {:?}", + component, + p.file_name().unwrap_or_default() + ); + return Ok(p); + } + + // 4. Fall back to any .onnx match + if let Some(p) = plain_candidates.into_iter().next() { + log::debug!( + "Found {} model by directory scan: {:?}", + component, + p.file_name().unwrap_or_default() + ); + return Ok(p); + } + } + + Err(TranscribeError::ModelNotFound(exact_suffixed)) + } + + /// Find a name from a list of candidates, returning the first match. + fn find_name(names: &[String], candidates: &[&str]) -> Option { + for &candidate in candidates { + if let Some(n) = names.iter().find(|n| n.as_str() == candidate) { + return Some(n.clone()); + } + } + None + } + + /// Transcribe with model-specific parameters. + pub fn transcribe_with( + &mut self, + samples: &[f32], + _params: &ZipformerTransducerParams, + ) -> Result { + self.infer(samples) + } + + fn infer(&mut self, samples: &[f32]) -> Result { + // 1. Compute Kaldi FBank features [frames, 80] + let features = compute_kaldi_fbank(samples, &KaldiFbankConfig::default()); + + if features.nrows() == 0 { + return Ok(TranscriptionResult { + text: String::new(), + segments: None, + }); + } + + log::debug!("Kaldi fbank: [{}, {}]", features.nrows(), features.ncols()); + + // 2. RNN-T greedy search + let token_ids = self.greedy_search(&features)?; + + // 3. Decode tokens to text + let text = self.symbol_table.decode(&token_ids); + + Ok(TranscriptionResult { + text, + segments: None, + }) + } + + /// Run encoder: features [1,T,80] + lens [1] -> encoder_out [1,T',D] + encoder_out_lens [1] + fn run_encoder( + &mut self, + features: &Array2, + ) -> Result<(Array3, i64), TranscribeError> { + let num_frames = features.nrows(); + let feat_3d = + features + .to_owned() + .into_shape_with_order((1, num_frames, features.ncols()))?; + let lens = ndarray::arr1(&[num_frames as i64]).into_dyn(); + + let feat_dyn = feat_3d.into_dyn(); + let t_feat = TensorRef::from_array_view(feat_dyn.view())?; + let t_lens = TensorRef::from_array_view(lens.view())?; + + let inputs = inputs![ + self.enc_x_name.as_str() => t_feat, + self.enc_x_lens_name.as_str() => t_lens, + ]; + let outputs = self.encoder_session.run(inputs)?; + + let encoder_out = outputs + .get(self.enc_out_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "encoder output '{}' not found", + self.enc_out_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + let encoder_out_lens = outputs + .get(self.enc_out_lens_name.as_str()) + .and_then(|v| v.try_extract_array::().ok()) + .and_then(|arr| arr.as_slice().and_then(|s| s.first().copied())) + .unwrap_or(encoder_out.shape()[1] as i64); + + Ok((encoder_out, encoder_out_lens)) + } + + /// Run decoder: y [1, context_size] (i64) -> decoder_out [1, D] + fn run_decoder(&mut self, context: &[i64]) -> Result, TranscribeError> { + let y = ndarray::Array2::from_shape_vec((1, self.context_size), context.to_vec())?; + let y_dyn = y.into_dyn(); + let t_y = TensorRef::from_array_view(y_dyn.view())?; + + let inputs = inputs![ + self.dec_y_name.as_str() => t_y, + ]; + let outputs = self.decoder_session.run(inputs)?; + + let decoder_out = outputs + .get(self.dec_out_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "decoder output '{}' not found", + self.dec_out_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + Ok(decoder_out) + } + + /// Run joiner: encoder_out_frame [1,D] + decoder_out [1,D] -> logit [1, vocab_size] + fn run_joiner( + &mut self, + encoder_out_frame: &ArrayView1, + decoder_out: &Array2, + ) -> Result, TranscribeError> { + let enc = encoder_out_frame + .to_owned() + .into_shape_with_order((1, encoder_out_frame.len()))? + .into_dyn(); + let dec_dyn = decoder_out.clone().into_dyn(); + + let t_enc = TensorRef::from_array_view(enc.view())?; + let t_dec = TensorRef::from_array_view(dec_dyn.view())?; + + let inputs = inputs![ + self.join_enc_name.as_str() => t_enc, + self.join_dec_name.as_str() => t_dec, + ]; + let outputs = self.joiner_session.run(inputs)?; + + let logit = outputs + .get(self.join_logit_name.as_str()) + .ok_or_else(|| { + TranscribeError::Inference(format!( + "joiner output '{}' not found", + self.join_logit_name + )) + })? + .try_extract_array::()? + .to_owned() + .into_dimensionality::()?; + + Ok(logit) + } + + /// Greedy search decoding for RNN-T (transducer). + fn greedy_search(&mut self, features: &Array2) -> Result, TranscribeError> { + let (encoder_out, encoder_out_lens) = self.run_encoder(features)?; + let t_max = (encoder_out_lens as usize).min(encoder_out.shape()[1]); + + let mut context = vec![self.blank_id as i64; self.context_size]; + let mut decoder_out = self.run_decoder(&context)?; + + let mut tokens = Vec::new(); + + for t in 0..t_max { + let enc_frame = encoder_out.slice(ndarray::s![0, t, ..]); + let logit = self.run_joiner(&enc_frame, &decoder_out)?; + + let logit_row = logit.row(0); + let mut max_id = 0usize; + let mut max_val = f32::NEG_INFINITY; + for (i, &v) in logit_row.iter().enumerate() { + if v > max_val { + max_val = v; + max_id = i; + } + } + + if max_id as i32 != self.blank_id { + tokens.push(max_id as i32); + context.rotate_left(1); + *context.last_mut().unwrap() = max_id as i64; + decoder_out = self.run_decoder(&context)?; + } + } + + Ok(tokens) + } +} + +impl SpeechModel for ZipformerTransducerModel { + fn capabilities(&self) -> ModelCapabilities { + CAPABILITIES + } + + fn transcribe_raw( + &mut self, + samples: &[f32], + _options: &TranscribeOptions, + ) -> Result { + self.infer(samples) + } +} diff --git a/src/punct.rs b/src/punct.rs new file mode 100644 index 0000000..97d69fd --- /dev/null +++ b/src/punct.rs @@ -0,0 +1,462 @@ +//! Neural network-based Chinese/English punctuation restoration. +//! +//! Uses a CT-Transformer ONNX model to insert punctuation into raw ASR text. +//! The model operates at the character/token level and supports both Chinese +//! (full-width punctuation) and English (ASCII punctuation) contexts. +//! +//! # Feature gate +//! +//! This module requires the `punct` feature: +//! +//! ```toml +//! [dependencies] +//! transcribe-rs = { version = "0.3", features = ["punct"] } +//! ``` +//! +//! # Model files +//! +//! You need the CT-Transformer model directory containing: +//! - `model.int8.onnx` (or `model.onnx` as fallback) +//! - `tokens.json` (vocabulary file as a JSON array of strings) +//! +//! # Usage +//! +//! ```ignore +//! use std::path::Path; +//! use transcribe_rs::punct::PunctModel; +//! +//! let mut model = PunctModel::new(Path::new("models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8/"))?; +//! let result = model.add_punctuation("今天天气很好我们去公园玩吧"); +//! println!("{}", result); +//! // => "今天天气很好,我们去公园玩吧。" +//! ``` + +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; + +use ndarray::{Array1, Array2}; +use ort::inputs; +use ort::session::builder::GraphOptimizationLevel; +use ort::session::Session; +use ort::value::TensorRef; + +use crate::TranscribeError; + +// ── Punctuation type enum ──────────────────────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PunctType { + Underscore = 0, + Comma = 2, + Dot = 3, + Quest = 4, + Pause = 5, +} + +impl PunctType { + fn from_id(id: usize) -> Option { + match id { + 0 => Some(PunctType::Underscore), + 2 => Some(PunctType::Comma), + 3 => Some(PunctType::Dot), + 4 => Some(PunctType::Quest), + 5 => Some(PunctType::Pause), + _ => None, + } + } + + fn to_char(self) -> Option { + match self { + PunctType::Underscore => None, + PunctType::Comma => Some(','), + PunctType::Dot => Some('。'), + PunctType::Quest => Some('?'), + PunctType::Pause => Some('、'), + } + } + + fn to_ascii_char(self) -> Option { + match self { + PunctType::Underscore => None, + PunctType::Comma => Some(','), + PunctType::Dot => Some('.'), + PunctType::Quest => Some('?'), + PunctType::Pause => Some(','), + } + } +} + +// ── Token classification ───────────────────────────────────────────────────── + +#[derive(Debug, Clone)] +enum TokenInfo { + Word(String), + Char(char), + Space, + Punct(char), +} + +// ── PunctModel ──────────────────────────────────────────────────────────────── + +/// CT-Transformer ONNX-based punctuation restoration model. +/// +/// Tokenizes input text, runs inference in 20-token windows with 2-token +/// overlap, and reconstructs text with intelligently placed punctuation marks. +pub struct PunctModel { + session: Session, + token2id: HashMap, + unk_id: i32, + input_name: String, + length_name: String, +} + +impl PunctModel { + /// Load a PunctModel from a model directory. + /// + /// The directory must contain `model.int8.onnx` (or `model.onnx`) and + /// `tokens.json`. + pub fn new(model_dir: &Path) -> Result { + let model_path = model_dir.join("model.int8.onnx"); + let model_path = if !model_path.exists() { + model_dir.join("model.onnx") + } else { + model_path + }; + let tokens_path = model_dir.join("tokens.json"); + + if !model_path.exists() { + return Err(TranscribeError::ModelNotFound(model_path)); + } + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(tokens_path)); + } + + log::info!("Loading punctuation model from {:?}...", model_path); + + let session = Session::builder() + .map_err(|e| TranscribeError::Config(format!("ort session builder: {e}")))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| TranscribeError::Config(format!("ort optimization level: {e}")))? + .with_parallel_execution(true) + .map_err(|e| TranscribeError::Config(format!("ort parallel execution: {e}")))? + .commit_from_file(&model_path) + .map_err(|e| TranscribeError::Inference(format!("failed to load punct model: {e}")))?; + + let (token2id, unk_id) = Self::load_tokens(&tokens_path)?; + + // session.inputs() returns &[Outlet] — index directly + let input_name = session.inputs()[0].name().to_string(); + let length_name = session.inputs()[1].name().to_string(); + + log::debug!( + "Punct model input names: '{}' and '{}'", + input_name, + length_name + ); + + Ok(Self { + session, + token2id, + unk_id, + input_name, + length_name, + }) + } + + fn load_tokens(path: &Path) -> Result<(HashMap, i32), TranscribeError> { + let file = File::open(path)?; + let tokens: Vec = serde_json::from_reader(file)?; + let mut token2id = HashMap::new(); + for (id, token) in tokens.iter().enumerate() { + token2id.insert(token.clone(), id as i32); + } + let unk_id = *token2id.get("").unwrap_or(&0); + log::debug!("Loaded {} tokens, unk_id={}", tokens.len(), unk_id); + Ok((token2id, unk_id)) + } + + /// Tokenize input text into (token_ids, token_infos). + /// + /// - CJK characters are tokenized one character at a time. + /// - ASCII words/digit runs are kept as a single token. + /// - Whitespace is recorded as `TokenInfo::Space` (not submitted to model). + /// - Existing punctuation is recorded as `TokenInfo::Punct` and skipped. + fn tokenize(&self, text: &str) -> (Vec, Vec) { + let mut ids: Vec = Vec::new(); + let mut infos: Vec = Vec::new(); + + let mut chars = text.chars().peekable(); + while let Some(c) = chars.next() { + if c.is_whitespace() { + infos.push(TokenInfo::Space); + continue; + } + + // Existing punctuation — preserve as-is, don't tokenize + if is_existing_punct(c) { + infos.push(TokenInfo::Punct(c)); + continue; + } + + if is_cjk(c) { + // Single CJK character + let token = c.to_string(); + let id = *self.token2id.get(&token).unwrap_or(&self.unk_id); + ids.push(id); + infos.push(TokenInfo::Char(c)); + } else { + // Collect a run of non-CJK, non-space, non-punct chars as a word + let mut word = String::new(); + word.push(c); + while let Some(&nc) = chars.peek() { + if nc.is_whitespace() || is_cjk(nc) || is_existing_punct(nc) { + break; + } + word.push(nc); + chars.next(); + } + let lower = word.to_lowercase(); + let id = self + .token2id + .get(&lower) + .copied() + .unwrap_or_else(|| *self.token2id.get(&word).unwrap_or(&self.unk_id)); + ids.push(id); + infos.push(TokenInfo::Word(word)); + } + } + + (ids, infos) + } + + /// Run punctuation inference on a batch of token IDs. + /// + /// Returns a Vec of punctuation class IDs, one per input token. + fn run_inference(&mut self, token_ids: &[i32]) -> Result, TranscribeError> { + let seq_len = token_ids.len(); + let input_array = Array2::from_shape_vec((1, seq_len), token_ids.to_vec()) + .map_err(|e| TranscribeError::Inference(format!("shape error: {e}")))?; + let length_array = Array1::from_vec(vec![seq_len as i32]); + + let input_tensor = TensorRef::from_array_view(input_array.view()) + .map_err(|e| TranscribeError::Inference(format!("input tensor: {e}")))?; + let length_tensor = TensorRef::from_array_view(length_array.view()) + .map_err(|e| TranscribeError::Inference(format!("length tensor: {e}")))?; + + let outputs = self + .session + .run(inputs![ + self.input_name.as_str() => input_tensor, + self.length_name.as_str() => length_tensor + ]) + .map_err(|e| TranscribeError::Inference(format!("inference: {e}")))?; + + // Output is logits [batch=1, seq_len, num_classes] — argmax along last axis + let punct_ids: Vec = if let Ok(logits) = outputs[0].try_extract_array::() { + let shape = logits.shape(); + if shape.len() == 3 { + // [1, seq_len, num_classes] → argmax per token + let num_classes = shape[2]; + logits + .as_slice() + .unwrap() + .chunks(num_classes) + .take(seq_len) + .map(|row| { + row.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .map(|(idx, _)| idx) + .unwrap_or(0) + }) + .collect() + } else { + // Unexpected shape, return zeros + vec![0usize; seq_len] + } + } else if let Ok(output) = outputs[0].try_extract_array::() { + // Some models output pre-argmaxed int64 + output.iter().map(|&x| x as usize).collect() + } else { + let output = outputs[0] + .try_extract_array::() + .map_err(|e| TranscribeError::Inference(format!("extract output: {e}")))?; + output.iter().map(|&x| x as usize).collect() + }; + Ok(punct_ids) + } + + /// Add punctuation to raw ASR text. + /// + /// On inference error, logs a warning and returns the original text unchanged. + pub fn add_punctuation(&mut self, text: &str) -> String { + if text.is_empty() { + return text.to_string(); + } + + let (token_ids, token_infos) = self.tokenize(text); + + if token_ids.is_empty() { + return text.to_string(); + } + + // Process in windows of 20 tokens with 2-token overlap + const WINDOW: usize = 20; + const OVERLAP: usize = 2; + + let mut punctuations: Vec = vec![0usize; token_ids.len()]; + + let mut start = 0; + loop { + let end = (start + WINDOW).min(token_ids.len()); + let chunk = &token_ids[start..end]; + + match self.run_inference(chunk) { + Ok(preds) => { + // Only take the non-overlapping part (except for the last chunk) + let take = if end < token_ids.len() { + preds.len().saturating_sub(OVERLAP) + } else { + preds.len() + }; + for (i, &p) in preds[..take].iter().enumerate() { + punctuations[start + i] = p; + } + } + Err(e) => { + log::warn!("Punctuation inference failed: {}", e); + return text.to_string(); + } + } + + if end >= token_ids.len() { + break; + } + start += WINDOW - OVERLAP; + } + + self.reconstruct_with_punctuation(&token_infos, &punctuations) + } + + /// Reconstruct the output string by interleaving tokens with predicted punctuation. + fn reconstruct_with_punctuation( + &self, + token_infos: &[TokenInfo], + punctuations: &[usize], + ) -> String { + let mut result = String::new(); + let mut punct_iter = punctuations.iter(); + + for info in token_infos { + match info { + TokenInfo::Space => { + // Spaces are absorbed; punctuation takes their place + } + TokenInfo::Punct(c) => { + result.push(*c); + } + TokenInfo::Char(c) => { + result.push(*c); + if let Some(&pid) = punct_iter.next() { + if let Some(pt) = PunctType::from_id(pid) { + // CJK character → use full-width punctuation + if let Some(pc) = pt.to_char() { + result.push(pc); + } + } + } + } + TokenInfo::Word(w) => { + result.push_str(w); + if let Some(&pid) = punct_iter.next() { + if let Some(pt) = PunctType::from_id(pid) { + let pc = choose_punct_char(pt, w, &result); + if let Some(pc) = pc { + result.push(pc); + } + } + } + } + } + } + + result + } +} + +// ── Helper functions ───────────────────────────────────────────────────────── + +/// Choose full-width or ASCII punctuation based on surrounding context. +fn choose_punct_char(pt: PunctType, current_word: &str, result_so_far: &str) -> Option { + // If the current word is an English/ASCII word, use ASCII punctuation. + // If the preceding content ends in a CJK character, use full-width. + let last_meaningful = result_so_far.chars().rev().find(|c| !c.is_whitespace()); + + let use_ascii = is_english_token(current_word) + || last_meaningful + .map(|c| !is_cjk(c) && c.is_ascii_alphanumeric()) + .unwrap_or(false); + + if use_ascii { + pt.to_ascii_char() + } else { + pt.to_char() + } +} + +/// Return true if the character is a punctuation mark that should be preserved. +fn is_existing_punct(c: char) -> bool { + c.is_ascii_punctuation() + || matches!( + c, + ',' | '。' + | '?' + | '!' + | '、' + | ';' + | ':' + | '…' + | '\u{201C}' // " + | '\u{201D}' // " + | '\u{2018}' // ' + | '\u{2019}' // ' + | '—' + | '【' + | '】' + | '《' + | '》' + | '(' + | ')' + ) +} + +/// Return true if the token looks like an English/ASCII word or number. +fn is_english_token(token: &str) -> bool { + !token.is_empty() + && token + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '\'') +} + +/// Return true if the token consists entirely of CJK characters. +#[allow(dead_code)] +fn is_cjk_token(token: &str) -> bool { + !token.is_empty() && token.chars().all(is_cjk) +} + +/// Return true if the character is in a CJK Unicode block. +fn is_cjk(c: char) -> bool { + matches!(c, + '\u{4E00}'..='\u{9FFF}' // CJK Unified Ideographs + | '\u{3400}'..='\u{4DBF}' // CJK Extension A + | '\u{20000}'..='\u{2A6DF}' // CJK Extension B + | '\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs + | '\u{2F800}'..='\u{2FA1F}' // CJK Compatibility Supplement + | '\u{3000}'..='\u{303F}' // CJK Symbols and Punctuation + | '\u{31F0}'..='\u{31FF}' // Katakana Phonetic Extensions + | '\u{3200}'..='\u{32FF}' // Enclosed CJK + | '\u{3300}'..='\u{33FF}' // CJK Compatibility + | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables + ) +} diff --git a/tests/paraformer.rs b/tests/paraformer.rs new file mode 100644 index 0000000..7741d37 --- /dev/null +++ b/tests/paraformer.rs @@ -0,0 +1,31 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::paraformer::ParaformerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_paraformer_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-paraformer-zh-2025-10-07"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = + ParaformerModel::load(&model_dir, &Quantization::Int8).expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +} diff --git a/tests/zipformer_ctc.rs b/tests/zipformer_ctc.rs new file mode 100644 index 0000000..99a18d9 --- /dev/null +++ b/tests/zipformer_ctc.rs @@ -0,0 +1,31 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::zipformer_ctc::ZipformerCtcModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_zipformer_ctc_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-zipformer-ctc-small-zh-int8-2025-07-16"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = + ZipformerCtcModel::load(&model_dir, &Quantization::Int8).expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +} diff --git a/tests/zipformer_transducer.rs b/tests/zipformer_transducer.rs new file mode 100644 index 0000000..646546b --- /dev/null +++ b/tests/zipformer_transducer.rs @@ -0,0 +1,31 @@ +mod common; + +use std::path::PathBuf; + +use transcribe_rs::onnx::zipformer_transducer::ZipformerTransducerModel; +use transcribe_rs::onnx::Quantization; +use transcribe_rs::SpeechModel; + +#[test] +fn test_zipformer_transducer_transcribe() { + env_logger::init(); + + let model_dir = PathBuf::from("models/sherpa-onnx-zipformer-zh-en-2023-11-22"); + let wav_path = PathBuf::from("samples/zh.wav"); + + if !common::require_paths(&[&model_dir, &wav_path]) { + return; + } + + let mut model = ZipformerTransducerModel::load(&model_dir, &Quantization::Int8) + .expect("Failed to load model"); + + let result = model + .transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default()) + .expect("Failed to transcribe"); + + assert!( + !result.text.is_empty(), + "Transcription result should not be empty" + ); +}