diff --git a/.gitignore b/.gitignore index 561d3b4..e12793f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea target/ *.onnx *.onnx.* @@ -7,6 +8,7 @@ venv/ vits-piper* sherpa-onnx/ *.wav +*.vocab !samples/*.wav *.bz2 sherpa-onnx-whisper* diff --git a/crates/sherpa-rs/Cargo.toml b/crates/sherpa-rs/Cargo.toml index 7eb5e37..7c7cc4a 100644 --- a/crates/sherpa-rs/Cargo.toml +++ b/crates/sherpa-rs/Cargo.toml @@ -109,4 +109,12 @@ path = "../../examples/sense_voice.rs" [[example]] name = "paraformer" -path = "../../examples/paraformer.rs" \ No newline at end of file +path = "../../examples/paraformer.rs" + +[[example]] +name = "transducer" +path = "../../examples/transducer.rs" + +[[example]] +name = "transducer_vosk" +path = "../../examples/transducer_vosk.rs" \ No newline at end of file diff --git a/crates/sherpa-rs/src/lib.rs b/crates/sherpa-rs/src/lib.rs index 0dc55f1..d585dc7 100644 --- a/crates/sherpa-rs/src/lib.rs +++ b/crates/sherpa-rs/src/lib.rs @@ -8,6 +8,7 @@ pub mod paraformer; pub mod punctuate; pub mod sense_voice; pub mod speaker_id; +pub mod transducer; pub mod vad; pub mod whisper; pub mod zipformer; diff --git a/crates/sherpa-rs/src/transducer.rs b/crates/sherpa-rs/src/transducer.rs new file mode 100644 index 0000000..c8c5f6a --- /dev/null +++ b/crates/sherpa-rs/src/transducer.rs @@ -0,0 +1,152 @@ +use crate::utils::cstr_to_string; +use crate::{get_default_provider, utils::cstring_from_str}; +use eyre::{bail, Result}; +use std::mem; + +pub struct TransducerRecognizer { + recognizer: *const sherpa_rs_sys::SherpaOnnxOfflineRecognizer, +} + +#[derive(Debug, Clone)] +pub struct TransducerConfig { + pub decoder: String, + pub encoder: String, + pub joiner: String, + pub tokens: String, + pub num_threads: i32, + pub sample_rate: i32, + pub feature_dim: i32, + pub decoding_method: String, + pub hotwords_file: String, + pub hotwords_score: f32, + pub modeling_unit: String, + pub bpe_vocab: String, + pub blank_penalty: f32, + pub debug: bool, + pub 
provider: Option<String>, +} + +impl Default for TransducerConfig { + fn default() -> Self { + TransducerConfig { + decoder: String::new(), + encoder: String::new(), + joiner: String::new(), + tokens: String::new(), + num_threads: 1, + sample_rate: 0, + feature_dim: 0, + decoding_method: String::new(), + hotwords_file: String::new(), + hotwords_score: 0.0, + modeling_unit: String::new(), + bpe_vocab: String::new(), + blank_penalty: 0.0, + debug: false, + provider: None, + } + } +} + +impl TransducerRecognizer { + pub fn new(config: TransducerConfig) -> Result<Self> { + let recognizer = unsafe { + let debug = config.debug.into(); + let provider = config.provider.unwrap_or(get_default_provider()); + let provider_ptr = cstring_from_str(&provider); + + let encoder = cstring_from_str(&config.encoder); + let decoder = cstring_from_str(&config.decoder); + let joiner = cstring_from_str(&config.joiner); + let model_type = cstring_from_str("transducer"); + let modeling_unit = cstring_from_str(&config.modeling_unit); + let bpe_vocab = cstring_from_str(&config.bpe_vocab); + let hotwords_file = cstring_from_str(&config.hotwords_file); + let tokens = cstring_from_str(&config.tokens); + let decoding_method = cstring_from_str(&config.decoding_method); + + let offline_model_config = sherpa_rs_sys::SherpaOnnxOfflineModelConfig { + transducer: sherpa_rs_sys::SherpaOnnxOfflineTransducerModelConfig { + encoder: encoder.as_ptr(), + decoder: decoder.as_ptr(), + joiner: joiner.as_ptr(), + }, + tokens: tokens.as_ptr(), + num_threads: config.num_threads, + debug, + provider: provider_ptr.as_ptr(), + model_type: model_type.as_ptr(), + modeling_unit: modeling_unit.as_ptr(), + bpe_vocab: bpe_vocab.as_ptr(), + + // NULLs + telespeech_ctc: mem::zeroed::<_>(), + paraformer: mem::zeroed::<_>(), + tdnn: mem::zeroed::<_>(), + nemo_ctc: mem::zeroed::<_>(), + whisper: mem::zeroed::<_>(), + sense_voice: mem::zeroed::<_>(), + moonshine: mem::zeroed::<_>(), + fire_red_asr: mem::zeroed::<_>(), + }; + + let 
recognizer_config = sherpa_rs_sys::SherpaOnnxOfflineRecognizerConfig { + model_config: offline_model_config, + feat_config: sherpa_rs_sys::SherpaOnnxFeatureConfig { + sample_rate: config.sample_rate, + feature_dim: config.feature_dim, + }, + hotwords_file: hotwords_file.as_ptr(), + blank_penalty: config.blank_penalty, + decoding_method: decoding_method.as_ptr(), + hotwords_score: config.hotwords_score, + + // NULLs + lm_config: mem::zeroed::<_>(), + rule_fsts: mem::zeroed::<_>(), + rule_fars: mem::zeroed::<_>(), + max_active_paths: mem::zeroed::<_>(), + }; + + let recognizer = sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&recognizer_config); + if recognizer.is_null() { + bail!("SherpaOnnxCreateOfflineRecognizer failed"); + } + recognizer + }; + + Ok(Self { recognizer }) + } + + pub fn transcribe(&mut self, sample_rate: u32, samples: &[f32]) -> String { + unsafe { + let stream = sherpa_rs_sys::SherpaOnnxCreateOfflineStream(self.recognizer); + sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline( + stream, + sample_rate as i32, + samples.as_ptr(), + samples.len().try_into().unwrap(), + ); + sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream); + let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream); + let raw_result = result_ptr.read(); + let text = cstr_to_string(raw_result.text as _); + + // Free + sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr); + sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream); + text + } + } +} + +unsafe impl Send for TransducerRecognizer {} +unsafe impl Sync for TransducerRecognizer {} + +impl Drop for TransducerRecognizer { + fn drop(&mut self) { + unsafe { + sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizer(self.recognizer); + } + } +} diff --git a/examples/transducer.rs b/examples/transducer.rs new file mode 100644 index 0000000..ad861b5 --- /dev/null +++ b/examples/transducer.rs @@ -0,0 +1,44 @@ +use sherpa_rs::read_audio_file; +use sherpa_rs::transducer::{TransducerConfig, 
TransducerRecognizer}; +use std::time::Instant; + +/* +wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/decoder-epoch-90-avg-20.onnx +wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/encoder-epoch-90-avg-20.onnx +wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/joiner-epoch-90-avg-20.onnx +wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/tokens.txt +wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav +cargo run --example transducer motivation.wav +*/ + +pub fn main() { + let path = std::env::args().nth(1).expect("Missing file path argument"); + let (samples, sample_rate) = read_audio_file(&path).unwrap(); + + // Check if the sample rate is 16000 + if sample_rate != 16000 { + panic!("The sample rate must be 16000."); + } + + let config = TransducerConfig { + decoder: "decoder-epoch-90-avg-20.onnx".to_string(), + encoder: "encoder-epoch-90-avg-20.onnx".to_string(), + joiner: "joiner-epoch-90-avg-20.onnx".to_string(), + tokens: "tokens.txt".to_string(), + num_threads: 1, + sample_rate: 16_000, + feature_dim: 80, + debug: true, + ..Default::default() + }; + + let mut recognizer = TransducerRecognizer::new(config).unwrap(); + + let start_t = Instant::now(); + let result = recognizer.transcribe(sample_rate, &samples); + let lower_case = result.to_lowercase(); + let trimmed_result = lower_case.trim(); + + println!("Time taken for decode: {:?}", start_t.elapsed()); + println!("Transcribe result: {:?}", trimmed_result); +} diff --git a/examples/transducer_vosk.rs b/examples/transducer_vosk.rs new file mode 100644 index 0000000..0514234 --- /dev/null +++ b/examples/transducer_vosk.rs @@ -0,0 +1,50 @@ +use sherpa_rs::read_audio_file; +use sherpa_rs::transducer::{TransducerConfig, TransducerRecognizer}; +use 
std::time::Instant; + +/* +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/decoder.onnx +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/encoder.onnx +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/joiner.onnx +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/lang/tokens.txt +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/lang/unigram_500.vocab +wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/test.wav +touch hotwords.txt +cargo run --example transducer_vosk test.wav +*/ +pub fn main() { + let path = std::env::args().nth(1).expect("Missing file path argument"); + let (samples, sample_rate) = read_audio_file(&path).unwrap(); + + // Check if the sample rate is 16000 + if sample_rate != 16000 { + panic!("The sample rate must be 16000."); + } + + let config = TransducerConfig { + decoder: "decoder.onnx".to_string(), + encoder: "encoder.onnx".to_string(), + joiner: "joiner.onnx".to_string(), + tokens: "tokens.txt".to_string(), + bpe_vocab: "unigram_500.vocab".to_string(), + hotwords_file: "hotwords.txt".to_string(), + hotwords_score: 1.2, + num_threads: 1, + sample_rate: 16_000, + feature_dim: 80, + modeling_unit: "bpe".to_string(), + decoding_method: "modified_beam_search".to_string(), + debug: true, + ..Default::default() + }; + + let mut recognizer = TransducerRecognizer::new(config).unwrap(); + + let start_t = Instant::now(); + let result = recognizer.transcribe(sample_rate, &samples); + let lower_case = result.to_lowercase(); + let trimmed_result = lower_case.trim(); + + println!("Time taken for decode: {:?}", start_t.elapsed()); + println!("Transcribe result: {:?}", trimmed_result); +}