Skip to content

Commit 4acc882

Browse files
committed
Stabilize streaming and decoder state handling
1 parent 2f5e562 commit 4acc882

21 files changed

+1693
-1145
lines changed

Cargo.lock

Lines changed: 110 additions & 147 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src-tauri/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,13 @@ thiserror = "2.0.17"
2727
ort = "2.0.0-rc.10"
2828
num_cpus = "1.17.0"
2929
enigo = "0.6.1"
30-
reqwest = { version = "0.12.25", default-features = false, features = ["blocking", "rustls-tls", "gzip"] }
30+
ureq = { version = "2.12.1", features = ["json", "charset"] }
3131
dirs-next = "2.0.0"
3232
cpal = "0.16.0"
33+
tauri-plugin-single-instance = "2.3.6"
3334
tauri-plugin-store = "2.4.1"
3435
tauri-plugin-log = "2.7.1"
3536
tauri-plugin-dialog = "2.4.2"
3637
tauri-plugin-global-shortcut = "2.3.1"
37-
tauri-plugin-single-instance = "2.3.6"
38+
rtrb = "0.3.2"
39+
rubato = "0.16.2"

src-tauri/src/app.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ fn handle_run_event(app_handle: &AppHandle, event: RunEvent) {
8686
fn handle_run_event(_app_handle: &AppHandle, _event: RunEvent) {}
8787

8888
fn on_second_instance(app: &AppHandle, argv: Vec<String>, cwd: String) {
89-
log::info!("{}, {argv:?}, {cwd}", app.package_info().name);
89+
log::info!("Second instance detected (args={argv:?}, cwd={cwd})");
9090
if let Err(err) = app.emit("single-instance", ()) {
9191
log::error!("Failed to emit single-instance event: {err}");
9292
}

src-tauri/src/asr/audio_io.rs

Lines changed: 0 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1 @@
11
pub(crate) const TARGET_SAMPLE_RATE: u32 = 16_000;
2-
3-
pub fn resample_linear(input: &[f32], from_sr: u32, to_sr: u32) -> Vec<f32> {
4-
if from_sr == 0 || to_sr == 0 || input.is_empty() {
5-
return Vec::new();
6-
}
7-
8-
if from_sr == to_sr {
9-
return input.to_vec();
10-
}
11-
12-
let out_len = ((input.len() as f64) * (to_sr as f64) / (from_sr as f64))
13-
.ceil()
14-
.max(1.0) as usize;
15-
let step = from_sr as f64 / to_sr as f64;
16-
17-
let mut output = Vec::with_capacity(out_len);
18-
let input_len = input.len();
19-
20-
for i in 0..out_len {
21-
let pos = (i as f64) * step;
22-
let idx = pos.floor() as usize;
23-
let frac = (pos - idx as f64) as f32;
24-
25-
unsafe {
26-
let current = *input.get_unchecked(idx.min(input_len - 1));
27-
let next_idx = (idx + 1).min(input_len - 1);
28-
let next = *input.get_unchecked(next_idx);
29-
output.push(current + (next - current) * frac);
30-
}
31-
}
32-
33-
output
34-
}

src-tauri/src/asr/decoder.rs

Lines changed: 129 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ use ort::inputs;
66
use ort::value::TensorRef;
77
use regex::Regex;
88

9-
use crate::asr::recognizer::{AsrError, AsrModel, Transcript};
9+
use crate::asr::recognizer::{AsrError, AsrModel, InferenceConfig, Transcript};
1010

1111
type DecoderState = (Array3<f32>, Array3<f32>);
1212

@@ -17,102 +17,72 @@ const MAX_TOKENS_PER_STEP: usize = 10;
1717
static DECODE_SPACE_RE: LazyLock<Result<Regex, regex::Error>> =
1818
LazyLock::new(|| Regex::new(r"\A\s|\s\B|(\s)\b"));
1919

20-
pub(crate) struct DecoderWorkspace {
21-
encoder_step: Array3<f32>,
22-
targets: Array2<i32>,
23-
target_length: Array1<i32>,
24-
state: DecoderState,
20+
pub struct DecoderSession<'m> {
21+
model: &'m mut AsrModel,
22+
workspace: DecoderWorkspace,
23+
last_token: i32,
2524
}
2625

27-
impl DecoderWorkspace {
28-
pub(crate) fn new(session: &ort::session::Session) -> Result<Self, AsrError> {
29-
let encoder_dim = session
30-
.inputs
31-
.iter()
32-
.find(|input| input.name == "encoder_outputs")
33-
.and_then(|input| input.input_type.tensor_shape())
34-
.and_then(|shape| shape.get(1).copied())
35-
.and_then(|d| usize::try_from(d).ok())
36-
.unwrap_or(1024);
37-
38-
let state1_shape = session
39-
.inputs
40-
.iter()
41-
.find(|input| input.name == "input_states_1")
42-
.ok_or_else(|| AsrError::InputNotFound("input_states_1".to_string()))?
43-
.input_type
44-
.tensor_shape()
45-
.ok_or_else(|| AsrError::TensorShape("input_states_1".to_string()))?;
46-
47-
let state2_shape = session
48-
.inputs
49-
.iter()
50-
.find(|input| input.name == "input_states_2")
51-
.ok_or_else(|| AsrError::InputNotFound("input_states_2".to_string()))?
52-
.input_type
53-
.tensor_shape()
54-
.ok_or_else(|| AsrError::TensorShape("input_states_2".to_string()))?;
26+
impl<'m> DecoderSession<'m> {
27+
pub(crate) fn new(
28+
model: &'m mut AsrModel,
29+
workspace: Option<DecoderWorkspace>,
30+
last_token: i32,
31+
) -> Result<Self, AsrError> {
32+
let workspace = if let Some(ws) = workspace {
33+
log::debug!("Reusing cached decoder workspace");
34+
ws
35+
} else {
36+
log::debug!("Initializing new decoder workspace");
37+
DecoderWorkspace::new(&model.decoder_joint)?
38+
};
5539

56-
let state1 = Array::zeros((state1_shape[0] as usize, 1, state1_shape[2] as usize));
57-
let state2 = Array::zeros((state2_shape[0] as usize, 1, state2_shape[2] as usize));
40+
log::debug!("Decoder session initialized with last_token={}", last_token);
5841

5942
Ok(Self {
60-
encoder_step: Array::zeros((1, encoder_dim, 1)),
61-
targets: Array2::zeros((1, 1)),
62-
target_length: Array1::from_vec(vec![1]),
63-
state: (state1, state2),
43+
model,
44+
workspace,
45+
last_token,
6446
})
6547
}
6648

67-
#[inline]
68-
pub(crate) fn reset_state(&mut self) {
69-
self.state.0.fill(0.0);
70-
self.state.1.fill(0.0);
71-
}
72-
73-
pub(crate) fn set_encoder_step(&mut self, frame: &ArrayView1<f32>) {
74-
let mut view = self.encoder_step.index_axis_mut(ndarray::Axis(2), 0);
75-
let mut view = view.index_axis_mut(ndarray::Axis(0), 0);
76-
view.assign(frame);
77-
}
78-
79-
pub(crate) fn set_target(&mut self, token: i32) {
80-
self.targets[[0, 0]] = token;
49+
pub(crate) fn into_parts(self) -> (DecoderWorkspace, i32) {
50+
(self.workspace, self.last_token)
8151
}
82-
}
8352

84-
impl AsrModel {
8553
pub(crate) fn decode_sequence(
8654
&mut self,
8755
encodings: &ArrayViewD<f32>,
8856
encodings_len: usize,
57+
_config: &InferenceConfig,
8958
) -> Result<(Vec<i32>, Vec<usize>), AsrError> {
9059
let decode_start = Instant::now();
91-
let mut tokens = Vec::with_capacity(encodings_len / 2 + 4);
92-
let mut timestamps = Vec::with_capacity(encodings_len / 2 + 4);
93-
94-
let workspace = &mut self.decoder_workspace;
95-
workspace.reset_state();
60+
let mut tokens = Vec::with_capacity(std::cmp::max(1, encodings_len / 2));
61+
let mut timestamps = Vec::with_capacity(std::cmp::max(1, encodings_len / 2));
9662

9763
let mut t = 0;
9864
let mut emitted_tokens = 0;
9965

10066
while t < encodings_len {
10167
let encoder_step = encodings.slice(ndarray::s![t, ..]);
102-
workspace.set_encoder_step(&encoder_step);
68+
self.workspace.set_encoder_step(&encoder_step);
10369

104-
let target_token = tokens.last().copied().unwrap_or(self.blank_idx);
105-
workspace.set_target(target_token);
70+
let target_token = if let Some(last) = tokens.last() {
71+
*last
72+
} else {
73+
self.last_token
74+
};
75+
self.workspace.set_target(target_token);
10676

10777
let inputs = inputs![
108-
"encoder_outputs" => TensorRef::from_array_view(workspace.encoder_step.view())?,
109-
"targets" => TensorRef::from_array_view(workspace.targets.view())?,
110-
"target_length" => TensorRef::from_array_view(workspace.target_length.view())?,
111-
"input_states_1" => TensorRef::from_array_view(workspace.state.0.view())?,
112-
"input_states_2" => TensorRef::from_array_view(workspace.state.1.view())?,
78+
"encoder_outputs" => TensorRef::from_array_view(self.workspace.encoder_step.view())?,
79+
"targets" => TensorRef::from_array_view(self.workspace.targets.view())?,
80+
"target_length" => TensorRef::from_array_view(self.workspace.target_length.view())?,
81+
"input_states_1" => TensorRef::from_array_view(self.workspace.state.0.view())?,
82+
"input_states_2" => TensorRef::from_array_view(self.workspace.state.1.view())?,
11383
];
11484

115-
let outputs = self.decoder_joint.run(inputs)?;
85+
let outputs = self.model.decoder_joint.run(inputs)?;
11686

11787
let logits = outputs
11888
.get("outputs")
@@ -127,13 +97,8 @@ impl AsrModel {
12797
))
12898
})?;
12999

130-
let vocab_logits = if logits.len() > self.vocab_size {
131-
log::trace!(
132-
"TDT model detected: splitting {} logits into vocab({}) + duration",
133-
logits.len(),
134-
self.vocab_size
135-
);
136-
&vocab_logits_slice[..self.vocab_size]
100+
let vocab_logits = if logits.len() > self.model.vocab_size {
101+
&vocab_logits_slice[..self.model.vocab_size]
137102
} else {
138103
vocab_logits_slice
139104
};
@@ -142,10 +107,9 @@ impl AsrModel {
142107
.iter()
143108
.enumerate()
144109
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
145-
.map(|(idx, _)| idx as i32)
146-
.unwrap_or(self.blank_idx);
110+
.map_or(self.model.blank_idx, |(idx, _)| idx as i32);
147111

148-
if token != self.blank_idx {
112+
if token != self.model.blank_idx {
149113
let state1 = outputs
150114
.get("output_states_1")
151115
.ok_or_else(|| AsrError::OutputNotFound("output_states_1".to_string()))?
@@ -156,26 +120,37 @@ impl AsrModel {
156120
.try_extract_array::<f32>()?;
157121

158122
if let Ok(state1_view) = state1.view().into_dimensionality::<ndarray::Ix3>() {
159-
if workspace.state.0.shape() == state1_view.shape() {
160-
workspace.state.0.assign(&state1_view);
123+
if self.workspace.state.0.shape() == state1_view.shape() {
124+
self.workspace.state.0.assign(&state1_view);
161125
} else {
162-
workspace.state.0 = state1_view.to_owned();
126+
log::warn!(
127+
"Decoder state_1 shape changed: {:?} -> {:?}",
128+
self.workspace.state.0.shape(),
129+
state1_view.shape()
130+
);
131+
self.workspace.state.0 = state1_view.to_owned();
163132
}
164133
}
165134
if let Ok(state2_view) = state2.view().into_dimensionality::<ndarray::Ix3>() {
166-
if workspace.state.1.shape() == state2_view.shape() {
167-
workspace.state.1.assign(&state2_view);
135+
if self.workspace.state.1.shape() == state2_view.shape() {
136+
self.workspace.state.1.assign(&state2_view);
168137
} else {
169-
workspace.state.1 = state2_view.to_owned();
138+
log::warn!(
139+
"Decoder state_2 shape changed: {:?} -> {:?}",
140+
self.workspace.state.1.shape(),
141+
state2_view.shape()
142+
);
143+
self.workspace.state.1 = state2_view.to_owned();
170144
}
171145
}
172146

173147
tokens.push(token);
174148
timestamps.push(t);
175149
emitted_tokens += 1;
150+
self.last_token = token;
176151
}
177152

178-
if token == self.blank_idx || emitted_tokens == MAX_TOKENS_PER_STEP {
153+
if token == self.model.blank_idx || emitted_tokens == MAX_TOKENS_PER_STEP {
179154
t += 1;
180155
emitted_tokens = 0;
181156
}
@@ -190,7 +165,74 @@ impl AsrModel {
190165

191166
Ok((tokens, timestamps))
192167
}
168+
}
169+
170+
pub(crate) struct DecoderWorkspace {
171+
encoder_step: Array3<f32>,
172+
targets: Array2<i32>,
173+
target_length: Array1<i32>,
174+
state: DecoderState,
175+
}
176+
177+
impl DecoderWorkspace {
178+
pub(crate) fn new(session: &ort::session::Session) -> Result<Self, AsrError> {
179+
let encoder_dim = session
180+
.inputs
181+
.iter()
182+
.find(|input| input.name == "encoder_outputs")
183+
.and_then(|input| input.input_type.tensor_shape())
184+
.and_then(|shape| shape.get(1).copied())
185+
.and_then(|d| usize::try_from(d).ok());
186+
187+
let encoder_dim = match encoder_dim {
188+
Some(dim) => dim,
189+
None => {
190+
log::warn!("Could not determine encoder_dim from model, falling back to 1024");
191+
1024
192+
}
193+
};
194+
195+
let state1_shape = session
196+
.inputs
197+
.iter()
198+
.find(|input| input.name == "input_states_1")
199+
.ok_or_else(|| AsrError::InputNotFound("input_states_1".to_string()))?
200+
.input_type
201+
.tensor_shape()
202+
.ok_or_else(|| AsrError::TensorShape("input_states_1".to_string()))?;
203+
204+
let state2_shape = session
205+
.inputs
206+
.iter()
207+
.find(|input| input.name == "input_states_2")
208+
.ok_or_else(|| AsrError::InputNotFound("input_states_2".to_string()))?
209+
.input_type
210+
.tensor_shape()
211+
.ok_or_else(|| AsrError::TensorShape("input_states_2".to_string()))?;
212+
213+
let state1 = Array::zeros((state1_shape[0] as usize, 1, state1_shape[2] as usize));
214+
let state2 = Array::zeros((state2_shape[0] as usize, 1, state2_shape[2] as usize));
215+
216+
Ok(Self {
217+
encoder_step: Array::zeros((1, encoder_dim, 1)),
218+
targets: Array2::zeros((1, 1)),
219+
target_length: Array1::from_vec(vec![1]),
220+
state: (state1, state2),
221+
})
222+
}
223+
224+
pub(crate) fn set_encoder_step(&mut self, frame: &ArrayView1<f32>) {
225+
let mut view = self.encoder_step.index_axis_mut(ndarray::Axis(2), 0);
226+
let mut view = view.index_axis_mut(ndarray::Axis(0), 0);
227+
view.assign(frame);
228+
}
193229

230+
pub(crate) fn set_target(&mut self, token: i32) {
231+
self.targets[[0, 0]] = token;
232+
}
233+
}
234+
235+
impl AsrModel {
194236
pub(crate) fn decode_tokens(&self, ids: Vec<i32>, timestamps: Vec<usize>) -> Transcript {
195237
let tokens: Vec<String> = ids
196238
.iter()

src-tauri/src/asr/mod.rs

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,17 @@ pub mod download_progress;
44
mod model_store;
55
mod recognizer;
66

7-
use crate::vad::VadModel;
8-
use std::sync::{Arc, OnceLock};
97
use tauri::AppHandle;
108

11-
static VAD_MODEL: OnceLock<Arc<VadModel>> = OnceLock::new();
12-
13-
pub fn get_or_init_vad_model(app: &AppHandle) -> Arc<VadModel> {
14-
VAD_MODEL
15-
.get_or_init(|| {
16-
let path = model_store::vad_model_path(app);
17-
match VadModel::new(&path) {
18-
Ok(m) => Arc::new(m),
19-
Err(e) => {
20-
log::error!("Failed to load VAD model: {e}");
21-
panic!("VAD model failed to load at {}: {e}", path.display());
22-
}
23-
}
24-
})
25-
.clone()
9+
pub fn get_or_init_vad_model(app: &AppHandle) -> Result<std::path::PathBuf, String> {
10+
model_store::ensure_vad_model(app).map_err(|e| e.user_message().to_string())
2611
}
2712

2813
pub use model_store::{
29-
default_model_root, ensure_vad_model, fallback_model_root, resolve_model_dir, vad_model_path,
14+
default_model_root, ensure_vad_model, fallback_model_root, missing_model_files_for_tests,
15+
resolve_model_dir, vad_model_path,
3016
};
3117
pub use recognizer::{AsrError, AsrModel, Transcript};
3218

33-
pub(crate) use audio_io::{resample_linear, TARGET_SAMPLE_RATE};
19+
pub(crate) use audio_io::TARGET_SAMPLE_RATE;
3420
pub(crate) use download_progress::{current_download_progress, record_failure, DownloadProgress};

0 commit comments

Comments (0)