Skip to content

Commit 9537018

Browse files
committed
Enhance Paraformer support and add streaming example
- Updated `.gitignore` to include new Paraformer models. - Added `serde` and `serde_json` dependencies in `Cargo.toml` and `Cargo.lock`. - Introduced `OnlineRecognizerJsonResult` and `OnlineRecognizerResult` structs for handling recognition results. - Implemented `ParaformerOnlineRecognizer` with configuration options and transcription capabilities. - Added a new example for streaming transcription using Paraformer.
1 parent 8bb029e commit 9537018

File tree

6 files changed

+286
-0
lines changed

6 files changed

+286
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ sherpa-onnx-whisper*
1515
sherpa-onnx-pyannote-*
1616
sherpa-onnx-zipformer-*
1717
sherpa-onnx-moonshine-*
18+
sherpa-onnx-paraformer-*
19+
sherpa-onnx-streaming-paraformer-*
20+
sherpa-onnx-sense-voice-*
1821
*.txt
1922
!checksum.txt
2023
sherpa-onnx-v*

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/sherpa-rs/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ eyre = "0.6.12"
2323
hound = { version = "3.5.1" }
2424
sherpa-rs-sys = { path = "../sherpa-rs-sys", version = "0.6.8", default-features = false }
2525
tracing = "0.1.40"
26+
serde = { version = "1.0.216", features = ["derive"] }
27+
serde_json = "1.0.134"
2628

2729
[dev-dependencies]
2830
clap = { version = "4.5.8", features = ["derive"] }
@@ -139,3 +141,7 @@ path = "../../examples/dolphin.rs"
139141
[[example]]
140142
name = "parakeet"
141143
path = "../../examples/parakeet.rs"
144+
145+
[[example]]
146+
name = "paraformer_streaming"
147+
path = "../../examples/paraformer_streaming.rs"

crates/sherpa-rs/src/lib.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod tts;
2222

2323
use std::ffi::CStr;
2424

25+
use serde::{Deserialize, Serialize};
2526
#[cfg(feature = "sys")]
2627
pub use sherpa_rs_sys;
2728

@@ -135,3 +136,76 @@ impl Default for OnnxConfig {
135136
}
136137
}
137138
}
139+
140+
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
141+
pub struct OnlineRecognizerJsonResult {
142+
pub text: String,
143+
pub tokens: Vec<String>,
144+
pub timestamps: Vec<f32>,
145+
pub ys_probs: Vec<f32>,
146+
pub lm_probs: Vec<f32>,
147+
pub context_scores: Vec<f32>,
148+
pub segment: i32,
149+
pub words: Vec<Word>,
150+
pub start_time: f32,
151+
pub is_final: bool,
152+
pub is_eof: bool,
153+
}
154+
155+
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
156+
pub struct Word {
157+
pub word: String,
158+
pub start: f32,
159+
pub end: f32,
160+
}
161+
162+
#[derive(Debug, Clone)]
163+
pub struct OnlineRecognizerResult {
164+
pub text: String,
165+
pub timestamps: Vec<f32>,
166+
pub tokens: Vec<String>,
167+
/// Whether the result is final
168+
pub is_final: bool,
169+
/// segment id
170+
pub segment: i32,
171+
/// start_time of the segment in seconds
172+
pub start_time: f32,
173+
}
174+
175+
impl OnlineRecognizerResult {
176+
fn new(result: &sherpa_rs_sys::SherpaOnnxOnlineRecognizerResult) -> Self {
177+
let text = unsafe { cstr_to_string(result.text) };
178+
let count = result.count.try_into().unwrap();
179+
let timestamps = if result.timestamps.is_null() {
180+
Vec::new()
181+
} else {
182+
unsafe { std::slice::from_raw_parts(result.timestamps, count).to_vec() }
183+
};
184+
let mut tokens = Vec::with_capacity(count);
185+
let mut next_token = result.tokens;
186+
let json_str = unsafe { cstr_to_string(result.json) };
187+
let json: OnlineRecognizerJsonResult = serde_json::from_str(&json_str).unwrap_or_default();
188+
189+
for _ in 0..count {
190+
let token = unsafe { CStr::from_ptr(next_token) };
191+
tokens.push(token.to_string_lossy().into_owned());
192+
next_token = next_token
193+
.wrapping_byte_offset(token.to_bytes_with_nul().len().try_into().unwrap());
194+
}
195+
196+
Self {
197+
text,
198+
timestamps,
199+
tokens,
200+
is_final: json.is_final,
201+
segment: json.segment,
202+
start_time: json.start_time,
203+
}
204+
}
205+
}
206+
207+
impl From<&sherpa_rs_sys::SherpaOnnxOnlineRecognizerResult> for OnlineRecognizerResult {
208+
fn from(value: &sherpa_rs_sys::SherpaOnnxOnlineRecognizerResult) -> Self {
209+
Self::new(value)
210+
}
211+
}

crates/sherpa-rs/src/paraformer.rs

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub struct ParaformerRecognizer {
88
}
99

1010
pub type ParaformerRecognizerResult = super::OfflineRecognizerResult;
11+
pub type ParaformerOnlineRecognizerResult = super::OnlineRecognizerResult;
1112

1213
#[derive(Debug, Clone)]
1314
pub struct ParaformerConfig {
@@ -137,3 +138,154 @@ impl Drop for ParaformerRecognizer {
137138
}
138139
}
139140
}
141+
142+
#[derive(Debug, Clone)]
143+
pub struct ParaformerOnlineConfig {
144+
pub encoder_model_path: String,
145+
pub decoder_model_path: String,
146+
pub tokens: String,
147+
pub provider: Option<String>,
148+
pub num_threads: Option<i32>,
149+
pub debug: bool,
150+
pub enable_endpoint: Option<bool>,
151+
pub rule1_min_trailing_silence: Option<f32>,
152+
pub rule2_min_trailing_silence: Option<f32>,
153+
pub rule3_min_utterance_length: Option<f32>,
154+
}
155+
156+
impl Default for ParaformerOnlineConfig {
157+
fn default() -> Self {
158+
Self {
159+
encoder_model_path: String::new(),
160+
decoder_model_path: String::new(),
161+
tokens: String::new(),
162+
provider: None,
163+
num_threads: None,
164+
debug: false,
165+
enable_endpoint: None,
166+
rule1_min_trailing_silence: None,
167+
rule2_min_trailing_silence: None,
168+
rule3_min_utterance_length: None,
169+
}
170+
}
171+
}
172+
173+
#[derive(Debug)]
174+
pub struct ParaformerOnlineRecognizer {
175+
recognizer: *const sherpa_rs_sys::SherpaOnnxOnlineRecognizer,
176+
stream: *const sherpa_rs_sys::SherpaOnnxOnlineStream,
177+
segment_id: i32,
178+
}
179+
180+
impl ParaformerOnlineRecognizer {
181+
pub fn new(config: ParaformerOnlineConfig) -> Result<Self> {
182+
let debug = config.debug.into();
183+
let provider = config.provider.unwrap_or(get_default_provider());
184+
let provider_ptr = cstring_from_str(&provider);
185+
let tokens_ptr = cstring_from_str(&config.tokens);
186+
187+
let encoder_model_path = if config.encoder_model_path.is_empty() {
188+
bail!("Encoder model path is required for online Paraformer")
189+
} else {
190+
cstring_from_str(&config.encoder_model_path)
191+
};
192+
let decoder_model_path = if config.decoder_model_path.is_empty() {
193+
bail!("Decoder model path is required for online Paraformer")
194+
} else {
195+
cstring_from_str(&config.decoder_model_path)
196+
};
197+
let paraformer_config = sherpa_rs_sys::SherpaOnnxOnlineParaformerModelConfig {
198+
encoder: encoder_model_path.as_ptr(),
199+
decoder: decoder_model_path.as_ptr(),
200+
};
201+
let empty_str = cstring_from_str("");
202+
let mut model_config = sherpa_rs_sys::SherpaOnnxOnlineModelConfig::default();
203+
model_config.debug = debug;
204+
model_config.num_threads = config.num_threads.unwrap_or(1);
205+
model_config.provider = provider_ptr.as_ptr();
206+
model_config.tokens = tokens_ptr.as_ptr();
207+
model_config.paraformer = paraformer_config;
208+
209+
// Recognizer config
210+
let mut recognizer_config = sherpa_rs_sys::SherpaOnnxOnlineRecognizerConfig::default();
211+
recognizer_config.feat_config = sherpa_rs_sys::SherpaOnnxFeatureConfig {
212+
sample_rate: 16000,
213+
feature_dim: 80,
214+
};
215+
recognizer_config.model_config = model_config;
216+
recognizer_config.rule_fsts = empty_str.as_ptr();
217+
recognizer_config.rule_fars = empty_str.as_ptr();
218+
219+
recognizer_config.enable_endpoint = config.enable_endpoint.unwrap_or(false).into();
220+
recognizer_config.rule1_min_trailing_silence =
221+
config.rule1_min_trailing_silence.unwrap_or(2.4);
222+
recognizer_config.rule2_min_trailing_silence =
223+
config.rule2_min_trailing_silence.unwrap_or(1.2);
224+
recognizer_config.rule3_min_utterance_length =
225+
config.rule3_min_utterance_length.unwrap_or(300.0);
226+
227+
let recognizer =
228+
unsafe { sherpa_rs_sys::SherpaOnnxCreateOnlineRecognizer(&recognizer_config) };
229+
if recognizer.is_null() {
230+
bail!("Failed to create online Paraformer recognizer");
231+
}
232+
let stream = unsafe { sherpa_rs_sys::SherpaOnnxCreateOnlineStream(recognizer) };
233+
if stream.is_null() {
234+
unsafe {
235+
sherpa_rs_sys::SherpaOnnxDestroyOnlineRecognizer(recognizer);
236+
}
237+
bail!("Failed to create online Paraformer stream");
238+
}
239+
Ok(Self {
240+
recognizer,
241+
stream,
242+
segment_id: 0,
243+
})
244+
}
245+
246+
pub fn transcribe(
247+
&mut self,
248+
sample_rate: u32,
249+
samples: &[f32],
250+
) -> ParaformerOnlineRecognizerResult {
251+
unsafe {
252+
sherpa_rs_sys::SherpaOnnxOnlineStreamAcceptWaveform(
253+
self.stream,
254+
sample_rate as i32,
255+
samples.as_ptr(),
256+
samples.len() as i32,
257+
);
258+
259+
while sherpa_rs_sys::SherpaOnnxIsOnlineStreamReady(self.recognizer, self.stream) == 1 {
260+
sherpa_rs_sys::SherpaOnnxDecodeOnlineStream(self.recognizer, self.stream);
261+
}
262+
263+
let result_ptr =
264+
sherpa_rs_sys::SherpaOnnxGetOnlineStreamResult(self.recognizer, self.stream);
265+
let raw_result = result_ptr.read();
266+
let mut result = ParaformerOnlineRecognizerResult::from(&raw_result);
267+
sherpa_rs_sys::SherpaOnnxDestroyOnlineRecognizerResult(result_ptr);
268+
269+
if sherpa_rs_sys::SherpaOnnxOnlineStreamIsEndpoint(self.recognizer, self.stream) == 1 {
270+
self.segment_id += 1;
271+
sherpa_rs_sys::SherpaOnnxOnlineStreamReset(self.recognizer, self.stream);
272+
result.is_final = true;
273+
}
274+
275+
result.segment = self.segment_id;
276+
result
277+
}
278+
}
279+
}
280+
281+
unsafe impl Send for ParaformerOnlineRecognizer {}
282+
unsafe impl Sync for ParaformerOnlineRecognizer {}
283+
284+
impl Drop for ParaformerOnlineRecognizer {
285+
fn drop(&mut self) {
286+
unsafe {
287+
sherpa_rs_sys::SherpaOnnxDestroyOnlineStream(self.stream);
288+
sherpa_rs_sys::SherpaOnnxDestroyOnlineRecognizer(self.recognizer);
289+
}
290+
}
291+
}

examples/paraformer_streaming.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
Transcribe wav file using streaming Paraformer and punctuate the result
3+
4+
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
5+
tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
6+
7+
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
8+
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
9+
10+
cargo run --example paraformer_streaming motivation.wav
11+
*/
12+
13+
use sherpa_rs::{
14+
paraformer::{ParaformerOnlineConfig, ParaformerOnlineRecognizer},
15+
read_audio_file,
16+
};
17+
18+
fn main() {
19+
let path = std::env::args().nth(1).expect("Missing file path argument");
20+
let provider = std::env::args().nth(2).unwrap_or("cpu".into());
21+
let (samples, sample_rate) = read_audio_file(&path).unwrap();
22+
assert_eq!(sample_rate, 16000, "The sample rate must be 16000.");
23+
24+
let config = ParaformerOnlineConfig {
25+
tokens: "sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt".into(),
26+
provider: Some(provider),
27+
debug: true,
28+
encoder_model_path: "sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"
29+
.into(),
30+
decoder_model_path: "sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"
31+
.into(),
32+
enable_endpoint: Some(true),
33+
..Default::default()
34+
};
35+
36+
let mut recognizer = ParaformerOnlineRecognizer::new(config).unwrap();
37+
38+
for chunk in samples.chunks(1600) {
39+
let result = recognizer.transcribe(sample_rate, &chunk);
40+
if result.text.is_empty() {
41+
continue;
42+
}
43+
if result.is_final {
44+
println!("🎉 Final: {}", result.text);
45+
} else {
46+
println!("💬 Partial: {}", result.text);
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)