Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
target/
*.onnx
*.onnx.*
Expand All @@ -7,6 +8,7 @@ venv/
vits-piper*
sherpa-onnx/
*.wav
*.vocab
!samples/*.wav
*.bz2
sherpa-onnx-whisper*
Expand Down
10 changes: 9 additions & 1 deletion crates/sherpa-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,12 @@ path = "../../examples/sense_voice.rs"

[[example]]
name = "paraformer"
path = "../../examples/paraformer.rs"
path = "../../examples/paraformer.rs"

[[example]]
name = "transducer"
path = "../../examples/transducer.rs"

[[example]]
name = "transducer_vosk"
path = "../../examples/transducer_vosk.rs"
1 change: 1 addition & 0 deletions crates/sherpa-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod paraformer;
pub mod punctuate;
pub mod sense_voice;
pub mod speaker_id;
pub mod transducer;
pub mod vad;
pub mod whisper;
pub mod zipformer;
Expand Down
152 changes: 152 additions & 0 deletions crates/sherpa-rs/src/transducer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use crate::utils::cstr_to_string;
use crate::{get_default_provider, utils::cstring_from_str};
use eyre::{bail, Result};
use std::mem;

/// Offline (non-streaming) speech recognizer backed by a sherpa-onnx
/// transducer model (separate encoder / decoder / joiner ONNX files).
pub struct TransducerRecognizer {
    // Owned handle to the native recognizer; created in `new`, released in `Drop`.
    recognizer: *const sherpa_rs_sys::SherpaOnnxOfflineRecognizer,
}

#[derive(Debug, Clone)]
/// Configuration for [`TransducerRecognizer`].
///
/// String fields are file-system paths passed through to the sherpa-onnx C
/// API; empty strings mean "not set".
pub struct TransducerConfig {
    /// Path to the decoder ONNX model.
    pub decoder: String,
    /// Path to the encoder ONNX model.
    pub encoder: String,
    /// Path to the joiner ONNX model.
    pub joiner: String,
    /// Path to the tokens file (e.g. `tokens.txt`).
    pub tokens: String,
    /// Number of threads for ONNX Runtime inference.
    pub num_threads: i32,
    /// Sample rate used for feature extraction (fed to `SherpaOnnxFeatureConfig`).
    pub sample_rate: i32,
    /// Feature dimension used for feature extraction (e.g. 80).
    pub feature_dim: i32,
    /// Decoding method, e.g. `"greedy_search"` or `"modified_beam_search"`
    /// — empty delegates to the C library's default.
    pub decoding_method: String,
    /// Path to a hotwords file (optional; empty disables hotwords).
    pub hotwords_file: String,
    /// Boosting score applied to hotwords.
    pub hotwords_score: f32,
    /// Modeling unit, e.g. `"bpe"` — presumably matched against `bpe_vocab`;
    /// see sherpa-onnx docs for accepted values.
    pub modeling_unit: String,
    /// Path to a BPE vocabulary file (used together with `modeling_unit`).
    pub bpe_vocab: String,
    /// Penalty applied to the blank symbol during decoding.
    pub blank_penalty: f32,
    /// Enable verbose logging from the native library.
    pub debug: bool,
    /// ONNX Runtime execution provider; `None` falls back to
    /// `get_default_provider()`.
    pub provider: Option<String>,
}

impl Default for TransducerConfig {
fn default() -> Self {
TransducerConfig {
decoder: String::new(),
encoder: String::new(),
joiner: String::new(),
tokens: String::new(),
num_threads: 1,
sample_rate: 0,
feature_dim: 0,
decoding_method: String::new(),
hotwords_file: String::new(),
hotwords_score: 0.0,
modeling_unit: String::new(),
bpe_vocab: String::new(),
blank_penalty: 0.0,
debug: false,
provider: None,
}
}
}

impl TransducerRecognizer {
pub fn new(config: TransducerConfig) -> Result<Self> {
let recognizer = unsafe {
let debug = config.debug.into();
let provider = config.provider.unwrap_or(get_default_provider());
let provider_ptr = cstring_from_str(&provider);

let encoder = cstring_from_str(&config.encoder);
let decoder = cstring_from_str(&config.decoder);
let joiner = cstring_from_str(&config.joiner);
let model_type = cstring_from_str("transducer");
let modeling_unit = cstring_from_str(&config.modeling_unit);
let bpe_vocab = cstring_from_str(&config.bpe_vocab);
let hotwords_file = cstring_from_str(&config.hotwords_file);
let tokens = cstring_from_str(&config.tokens);
let decoding_method = cstring_from_str(&config.decoding_method);

let offline_model_config = sherpa_rs_sys::SherpaOnnxOfflineModelConfig {
transducer: sherpa_rs_sys::SherpaOnnxOfflineTransducerModelConfig {
encoder: encoder.as_ptr(),
decoder: decoder.as_ptr(),
joiner: joiner.as_ptr(),
},
tokens: tokens.as_ptr(),
num_threads: config.num_threads,
debug,
provider: provider_ptr.as_ptr(),
model_type: model_type.as_ptr(),
modeling_unit: modeling_unit.as_ptr(),
bpe_vocab: bpe_vocab.as_ptr(),

// NULLs
telespeech_ctc: mem::zeroed::<_>(),
paraformer: mem::zeroed::<_>(),
tdnn: mem::zeroed::<_>(),
nemo_ctc: mem::zeroed::<_>(),
whisper: mem::zeroed::<_>(),
sense_voice: mem::zeroed::<_>(),
moonshine: mem::zeroed::<_>(),
fire_red_asr: mem::zeroed::<_>(),
};

let recognizer_config = sherpa_rs_sys::SherpaOnnxOfflineRecognizerConfig {
model_config: offline_model_config,
feat_config: sherpa_rs_sys::SherpaOnnxFeatureConfig {
sample_rate: config.sample_rate,
feature_dim: config.feature_dim,
},
hotwords_file: hotwords_file.as_ptr(),
blank_penalty: config.blank_penalty,
decoding_method: decoding_method.as_ptr(),
hotwords_score: config.hotwords_score,

// NULLs
lm_config: mem::zeroed::<_>(),
rule_fsts: mem::zeroed::<_>(),
rule_fars: mem::zeroed::<_>(),
max_active_paths: mem::zeroed::<_>(),
};

let recognizer = sherpa_rs_sys::SherpaOnnxCreateOfflineRecognizer(&recognizer_config);
if recognizer.is_null() {
bail!("SherpaOnnxCreateOfflineRecognizer failed");
}
recognizer
};

Ok(Self { recognizer })
}

pub fn transcribe(&mut self, sample_rate: u32, samples: &[f32]) -> String {
unsafe {
let stream = sherpa_rs_sys::SherpaOnnxCreateOfflineStream(self.recognizer);
sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline(
stream,
sample_rate as i32,
samples.as_ptr(),
samples.len().try_into().unwrap(),
);
sherpa_rs_sys::SherpaOnnxDecodeOfflineStream(self.recognizer, stream);
let result_ptr = sherpa_rs_sys::SherpaOnnxGetOfflineStreamResult(stream);
let raw_result = result_ptr.read();
let text = cstr_to_string(raw_result.text as _);

// Free
sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizerResult(result_ptr);
sherpa_rs_sys::SherpaOnnxDestroyOfflineStream(stream);
text
}
}
}

// SAFETY: the raw pointer is an opaque handle owned exclusively by this
// wrapper. NOTE(review): these impls assume the sherpa-onnx C API tolerates
// calls on this recognizer from other threads (Send) and concurrent shared
// access (Sync) — not verifiable from this file; TODO confirm against the
// sherpa-onnx thread-safety guarantees.
unsafe impl Send for TransducerRecognizer {}
unsafe impl Sync for TransducerRecognizer {}

impl Drop for TransducerRecognizer {
    /// Releases the native recognizer handle when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.recognizer` was created by
        // `SherpaOnnxCreateOfflineRecognizer` in `new` and is destroyed
        // exactly once, here.
        unsafe {
            sherpa_rs_sys::SherpaOnnxDestroyOfflineRecognizer(self.recognizer);
        }
    }
}
44 changes: 44 additions & 0 deletions examples/transducer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use sherpa_rs::read_audio_file;
use sherpa_rs::transducer::{TransducerConfig, TransducerRecognizer};
use std::time::Instant;

/*
wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/decoder-epoch-90-avg-20.onnx
wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/encoder-epoch-90-avg-20.onnx
wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/joiner-epoch-90-avg-20.onnx
wget https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-libriheavy-20230926-small/resolve/main/tokens.txt
wget https://github.com/thewh1teagle/sherpa-rs/releases/download/v0.1.0/motivation.wav -O motivation.wav
cargo run --example transducer motivation.wav
*/

/// Example: transcribe a 16 kHz WAV file with a zipformer transducer model
/// (see the download commands above) and print the normalized transcript
/// together with the decode time.
pub fn main() {
    let audio_path = std::env::args().nth(1).expect("Missing file path argument");
    let (samples, sample_rate) = read_audio_file(&audio_path).unwrap();

    // The model only supports 16 kHz input.
    if sample_rate != 16000 {
        panic!("The sample rate must be 16000.");
    }

    let config = TransducerConfig {
        encoder: "encoder-epoch-90-avg-20.onnx".into(),
        decoder: "decoder-epoch-90-avg-20.onnx".into(),
        joiner: "joiner-epoch-90-avg-20.onnx".into(),
        tokens: "tokens.txt".into(),
        num_threads: 1,
        sample_rate: 16_000,
        feature_dim: 80,
        debug: true,
        ..Default::default()
    };
    let mut recognizer = TransducerRecognizer::new(config).unwrap();

    let start_t = Instant::now();
    let transcript = recognizer.transcribe(sample_rate, &samples);
    let normalized = transcript.to_lowercase();

    println!("Time taken for decode: {:?}", start_t.elapsed());
    println!("Transcribe result: {:?}", normalized.trim());
}
50 changes: 50 additions & 0 deletions examples/transducer_vosk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use sherpa_rs::read_audio_file;
use sherpa_rs::transducer::{TransducerConfig, TransducerRecognizer};
use std::time::Instant;

/*
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/decoder.onnx
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/encoder.onnx
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/am-onnx/joiner.onnx
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/lang/tokens.txt
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/lang/unigram_500.vocab
wget https://huggingface.co/alphacep/vosk-model-ru/resolve/main/test.wav
touch hotwords.txt
cargo run --example transducer_vosk test.wav
*/
/// Example: transcribe a 16 kHz WAV file with the Vosk Russian transducer
/// model (see the download commands above), using BPE hotwords and modified
/// beam search, then print the normalized transcript and the decode time.
pub fn main() {
    let audio_path = std::env::args().nth(1).expect("Missing file path argument");
    let (samples, sample_rate) = read_audio_file(&audio_path).unwrap();

    // The model only supports 16 kHz input.
    if sample_rate != 16000 {
        panic!("The sample rate must be 16000.");
    }

    let config = TransducerConfig {
        encoder: "encoder.onnx".into(),
        decoder: "decoder.onnx".into(),
        joiner: "joiner.onnx".into(),
        tokens: "tokens.txt".into(),
        bpe_vocab: "unigram_500.vocab".into(),
        hotwords_file: "hotwords.txt".into(),
        hotwords_score: 1.2,
        num_threads: 1,
        sample_rate: 16_000,
        feature_dim: 80,
        modeling_unit: "bpe".into(),
        decoding_method: "modified_beam_search".into(),
        debug: true,
        ..Default::default()
    };
    let mut recognizer = TransducerRecognizer::new(config).unwrap();

    let start_t = Instant::now();
    let transcript = recognizer.transcribe(sample_rate, &samples);
    let normalized = transcript.to_lowercase();

    println!("Time taken for decode: {:?}", start_t.elapsed());
    println!("Transcribe result: {:?}", normalized.trim());
}
Loading