Skip to content

Commit 89b1913

Browse files
authored
Merge pull request #37 from second-state/feat/stream_asr
Feat/stream asr
2 parents 36ee7c4 + c738ca8 commit 89b1913

13 files changed

Lines changed: 7609 additions & 1471 deletions

File tree

Cargo.lock

Lines changed: 6626 additions & 1305 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ rmcp = { version = "0.1.5", features = [
5151
"transport-streamable-http-client",
5252
"reqwest",
5353
], default-features = false, git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "b9d7d61" } # branch = "main"
54+
5455
base64 = "0.22.1"
5556
reqwest-websocket = "0.5.0"
5657
futures-util = "0.3.31"
@@ -59,3 +60,7 @@ tower = { version = "0.5.2", features = ["util"] }
5960
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
6061

6162
chrono = "0.4.41"
63+
64+
# vad
65+
silero_vad_burn = "0.1.1"
66+
burn = { version = "0.20", features = ["ndarray"] }

src/ai/bailian/realtime_asr.rs

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ impl ParaformerRealtimeV2Asr {
9393
Ok(())
9494
}
9595

96-
pub async fn start_pcm_recognition(&mut self) -> anyhow::Result<()> {
96+
pub async fn start_pcm_recognition(
97+
&mut self,
98+
semantic_punctuation_enabled: bool,
99+
) -> anyhow::Result<()> {
97100
let task_id = Uuid::new_v4().to_string();
98101
log::info!("Starting asr task with ID: {}", task_id);
99102
self.task_id = task_id;
@@ -112,6 +115,7 @@ impl ParaformerRealtimeV2Asr {
112115
"parameters": {
113116
"format": "pcm",
114117
"sample_rate": self.sample_rate,
118+
"semantic_punctuation_enabled": semantic_punctuation_enabled,
115119
},
116120
"input": {}
117121
},
@@ -163,6 +167,7 @@ impl ParaformerRealtimeV2Asr {
163167
"streaming": "duplex"
164168
},
165169
"payload": {
170+
"task_group": "audio",
166171
"input": {}
167172
}
168173
});
@@ -197,6 +202,7 @@ impl ParaformerRealtimeV2Asr {
197202
} else if let Some(output) = response.payload.output {
198203
return Ok(Some(output.sentence));
199204
} else {
205+
log::error!("ASR response has no output: {:?}", text);
200206
return Err(anyhow::anyhow!("ASR error: {:?}", text));
201207
}
202208
}
@@ -226,31 +232,116 @@ async fn test_paraformer_asr() {
226232
let mut asr = ParaformerRealtimeV2Asr::connect("", token, head.sample_rate)
227233
.await
228234
.unwrap();
229-
asr.start_pcm_recognition().await.unwrap();
235+
asr.start_pcm_recognition(false).await.unwrap();
230236

231237
asr.send_audio(audio_data.clone()).await.unwrap();
232238
asr.finish_task().await.unwrap();
233239

234240
loop {
235241
if let Ok(Some(sentence)) = asr.next_result().await {
236-
println!("{:?}", sentence);
242+
log::info!("{:?}", sentence);
237243
if sentence.sentence_end {
238-
println!();
244+
log::info!("Final sentence received, ending recognition session.");
239245
}
240246
} else {
241247
break;
242248
}
243249
}
244250

245-
asr.start_pcm_recognition().await.unwrap();
251+
asr.start_pcm_recognition(false).await.unwrap();
246252
asr.send_audio(audio_data).await.unwrap();
247253
asr.finish_task().await.unwrap();
248254

249255
loop {
250256
if let Ok(Some(sentence)) = asr.next_result().await {
251-
println!("{:?}", sentence);
257+
log::info!("{:?}", sentence);
258+
if sentence.sentence_end {
259+
log::info!("Final sentence received, ending recognition session.");
260+
}
261+
} else {
262+
break;
263+
}
264+
}
265+
}
266+
267+
// cargo test --package echokit_server --bin echokit_server -- ai::bailian::realtime_asr::test_paraformer_stream_asr --exact --show-output
268+
#[tokio::test]
269+
async fn test_paraformer_stream_asr() {
270+
env_logger::init();
271+
let token = std::env::var("COSYVOICE_TOKEN").unwrap();
272+
273+
let data = std::fs::read("./resources/test/out.wav").unwrap();
274+
let mut reader = wav_io::reader::Reader::from_vec(data).expect("Failed to create WAV reader");
275+
let header = reader.read_header().unwrap();
276+
let mut samples = crate::util::get_samples_f32(&mut reader).unwrap();
277+
278+
// pad 10 seconds of silence
279+
samples.extend_from_slice(&[0.0; 16000 * 10]);
280+
281+
let samples = crate::util::convert_samples_f32_to_i16_bytes(&samples);
282+
let audio_data = bytes::Bytes::from(samples);
283+
284+
let mut asr = ParaformerRealtimeV2Asr::connect("", token, header.sample_rate)
285+
.await
286+
.unwrap();
287+
asr.start_pcm_recognition(true).await.unwrap();
288+
289+
let mut ms = 0;
290+
291+
for chunk in audio_data.chunks(3200) {
292+
ms += 100;
293+
log::info!("Sending audio chunk at {} ms", ms);
294+
asr.send_audio(Bytes::copy_from_slice(chunk)).await.unwrap();
295+
// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
296+
let wait_asr_fut = asr.next_result();
297+
298+
let (sentence, has_result) = tokio::select! {
299+
res = wait_asr_fut => {
300+
(res.unwrap(),true)
301+
}
302+
_ = async {} => {
303+
(None,false)
304+
}
305+
};
306+
307+
if has_result {
308+
log::info!("{:?} {ms}", sentence);
309+
}
310+
311+
if let Some(s) = sentence {
312+
if s.sentence_end {
313+
break;
314+
}
315+
}
316+
}
317+
318+
asr.finish_task().await.unwrap();
319+
320+
loop {
321+
if let Ok(Some(sentence)) = asr.next_result().await {
322+
log::info!("{:?}", sentence);
323+
if sentence.sentence_end {
324+
log::info!("End of sentence");
325+
}
326+
} else {
327+
break;
328+
}
329+
}
330+
331+
asr.start_pcm_recognition(true).await.unwrap();
332+
333+
ms = 0;
334+
for chunk in audio_data.chunks(3200) {
335+
ms += 100;
336+
log::info!("Sending audio chunk at {} ms", ms);
337+
asr.send_audio(Bytes::copy_from_slice(chunk)).await.unwrap();
338+
// tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
339+
}
340+
loop {
341+
if let Ok(Some(sentence)) = asr.next_result().await {
342+
log::info!("{:?}", sentence);
252343
if sentence.sentence_end {
253-
println!();
344+
log::info!("End of sentence");
254345
}
255346
} else {
256347
break;

src/ai/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
473473
serde_json::to_string_pretty(&serde_json::json!(
474474
{
475475
"stream": true,
476-
"messages": messages,
476+
"last_message": messages.last(),
477477
"model": model.to_string(),
478478
"tools": tool_name,
479479
"extra": extra,

src/ai/vad.rs

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use futures_util::{
2-
stream::{SplitSink, SplitStream},
32
SinkExt, StreamExt,
3+
stream::{SplitSink, SplitStream},
44
};
55
use reqwest::multipart::Part;
66
use reqwest_websocket::{RequestBuilderExt, WebSocket};
@@ -101,3 +101,109 @@ impl VadRealtimeRx {
101101
}
102102
}
103103
}
104+
105+
pub type VadParams = crate::config::SileroVadConfig;
106+
107+
#[derive(Clone)]
108+
pub struct SileroVADFactory {
109+
device: burn::backend::ndarray::NdArrayDevice,
110+
params: VadParams,
111+
}
112+
113+
impl SileroVADFactory {
114+
pub fn new(params: VadParams) -> anyhow::Result<Self> {
115+
let device = burn::backend::ndarray::NdArrayDevice::default();
116+
117+
Ok(SileroVADFactory { device, params })
118+
}
119+
120+
pub fn create_session(&self) -> anyhow::Result<VadSession> {
121+
let vad = Box::new(silero_vad_burn::SileroVAD6Model::new(&self.device)?);
122+
VadSession::new(&self.params, vad, self.device.clone())
123+
}
124+
}
125+
126+
pub struct VadSession {
127+
vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
128+
state: Option<silero_vad_burn::PredictState<burn::backend::NdArray>>,
129+
device: burn::backend::ndarray::NdArrayDevice,
130+
131+
in_speech: bool,
132+
133+
threshold: f32,
134+
neg_threshold: f32,
135+
136+
silence_chunk_count: usize,
137+
max_silence_chunks: usize,
138+
}
139+
140+
impl VadSession {
141+
const SAMPLE_RATE: usize = 16000;
142+
143+
pub fn new(
144+
params: &VadParams,
145+
vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
146+
device: burn::backend::ndarray::NdArrayDevice,
147+
) -> anyhow::Result<Self> {
148+
let state = Some(silero_vad_burn::PredictState::default(&device));
149+
150+
let neg_threshold = params
151+
.neg_threshold
152+
.unwrap_or_else(|| params.threshold - 0.15)
153+
.max(0.05);
154+
155+
let threshold = params.threshold.min(0.95);
156+
let max_silence_chunks = params.max_silence_duration_ms * (Self::SAMPLE_RATE / 1000)
157+
/ silero_vad_burn::CHUNK_SIZE;
158+
159+
Ok(VadSession {
160+
vad,
161+
state,
162+
device,
163+
164+
in_speech: false,
165+
threshold,
166+
neg_threshold,
167+
168+
silence_chunk_count: 0,
169+
max_silence_chunks,
170+
})
171+
}
172+
173+
pub fn reset_state(&mut self) {
174+
self.state = Some(silero_vad_burn::PredictState::default(&self.device));
175+
self.in_speech = false;
176+
self.silence_chunk_count = 0;
177+
}
178+
179+
pub fn detect(&mut self, audio16k_chunk_512: &[f32]) -> anyhow::Result<bool> {
180+
debug_assert!(
181+
audio16k_chunk_512.len() <= 512,
182+
"audio16k_chunk_512 length must be less than 512",
183+
);
184+
185+
let audio_tensor =
186+
burn::Tensor::<_, 1>::from_floats(audio16k_chunk_512, &self.device).unsqueeze();
187+
let (state, prob) = self.vad.predict(self.state.take().unwrap(), audio_tensor)?;
188+
self.state = Some(state);
189+
190+
let prob: Vec<f32> = prob.to_data().to_vec()?;
191+
192+
if prob[0] > self.threshold {
193+
self.in_speech = true;
194+
self.silence_chunk_count = 0;
195+
} else if prob[0] < self.neg_threshold {
196+
self.silence_chunk_count += 1;
197+
if self.silence_chunk_count >= self.max_silence_chunks {
198+
self.in_speech = false;
199+
}
200+
} else {
201+
}
202+
203+
Ok(self.in_speech)
204+
}
205+
206+
pub const fn vad_chunk_size() -> usize {
207+
silero_vad_burn::CHUNK_SIZE
208+
}
209+
}

src/config.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,54 @@ pub enum TTSConfig {
242242
Elevenlabs(ElevenlabsTTS),
243243
}
244244

245+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
246+
pub struct SileroVadConfig {
247+
#[serde(default = "SileroVadConfig::default_threshold")]
248+
pub threshold: f32,
249+
#[serde(default = "SileroVadConfig::default_neg_threshold")]
250+
pub neg_threshold: Option<f32>,
251+
#[serde(default = "SileroVadConfig::default_min_speech_duration_ms")]
252+
pub min_speech_duration_ms: usize,
253+
#[serde(default = "SileroVadConfig::default_max_silence_duration_ms")]
254+
pub max_silence_duration_ms: usize,
255+
#[serde(default = "SileroVadConfig::hangover_ms")]
256+
pub hangover_ms: usize,
257+
}
258+
259+
impl SileroVadConfig {
260+
pub fn default_threshold() -> f32 {
261+
0.5
262+
}
263+
264+
pub fn default_neg_threshold() -> Option<f32> {
265+
None
266+
}
267+
268+
pub fn default_min_speech_duration_ms() -> usize {
269+
400
270+
}
271+
272+
pub fn default_max_silence_duration_ms() -> usize {
273+
200
274+
}
275+
276+
pub fn hangover_ms() -> usize {
277+
500
278+
}
279+
}
280+
281+
impl Default for SileroVadConfig {
282+
fn default() -> Self {
283+
SileroVadConfig {
284+
threshold: Self::default_threshold(),
285+
neg_threshold: Self::default_neg_threshold(),
286+
min_speech_duration_ms: Self::default_min_speech_duration_ms(),
287+
max_silence_duration_ms: Self::default_max_silence_duration_ms(),
288+
hangover_ms: Self::hangover_ms(),
289+
}
290+
}
291+
}
292+
245293
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
246294
pub struct WhisperASRConfig {
247295
pub url: String,
@@ -253,8 +301,14 @@ pub struct WhisperASRConfig {
253301
pub model: String,
254302
#[serde(default)]
255303
pub prompt: String,
304+
305+
#[serde(default)]
306+
pub vad: SileroVadConfig,
307+
308+
#[deprecated]
256309
#[serde(default)]
257310
pub vad_url: Option<String>,
311+
#[deprecated]
258312
#[serde(default)]
259313
pub vad_realtime_url: Option<String>,
260314
}

src/main.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::sync::Arc;
22

3-
use axum::{routing::any, Router};
3+
use axum::{
4+
Router,
5+
routing::{any, get},
6+
};
47
use clap::Parser;
58
use config::Config;
69

@@ -180,5 +183,13 @@ async fn routes(
180183
.layer(axum::Extension(Arc::new(real_config)));
181184
}
182185

183-
router
186+
router.route(
187+
"/version",
188+
get(|| async {
189+
axum::response::Json(serde_json::json!(
190+
{
191+
"version": env!("CARGO_PKG_VERSION"),
192+
}))
193+
}),
194+
)
184195
}

0 commit comments

Comments
 (0)