Skip to content

Commit 20da10f

Browse files
zhmiaoCopilot
andcommitted
feat(rp-27 part 2): mirror RawAudio orig_sample_rate dispatch into GPU path
The CPU detect_audio path got the 2-input ONNX dispatch in commit 5ea8641 (sparrow-engine v0.1.16). The GPU mirror in sparrow-engine-gpu was missed — spe-gpu loading orca-ecotype-dclde2026-v1 (the first model that opts into pass_orig_sample_rate=true) failed at both load-time probe and per-batch session.run with 'Missing Input: orig_sample_rate'. Wires the manifest flag through the GPU RawAudioModel: - RawAudio variant pattern: extract pass_orig_sample_rate alongside sample_rate + window_samples - New struct field on RawAudioModel - load_from_manifest probe: when flag is true, build a [1] int64 tensor carrying target_sample_rate (no-op for fill_highfreq) and run via the named Vec<(Cow<str>, SessionInputValue)> form - detect_inner per-batch: same pattern, but populates with the actual audio_samples.orig_sample_rate from the engine resampler End-to-end verified on RTX 6000 Ada with all 10 hydrophone WAV fixtures through spe-gpu detect-audio --visualize. Per-window calibrated softmax output matches the CPU smoke from earlier in the session bit-identically on the 24 kHz (no-resample) fixture; under-sampled fixtures differ within the same parity envelope as the CPU path (median top-1 delta 0.018). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 8017a89 commit 20da10f

1 file changed

Lines changed: 47 additions & 8 deletions

File tree

sparrow-engine/sparrow-engine-gpu/src/models/audio_raw.rs

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ pub struct RawAudioModel {
8484
/// (emits one segment per window regardless), but the value is held
8585
/// here for parity with the mel path + future API consistency.
8686
threshold: f32,
87+
/// RP-27 Part 2 opt-in: when true, pass orig_sample_rate as a second
88+
/// ONNX input alongside the audio tensor. Forwarded from the manifest.
89+
pass_orig_sample_rate: bool,
8790
}
8891

8992
// SAFETY: `Session` is wrapped in `Mutex` (Send + Sync via Mutex's bounds).
@@ -102,11 +105,12 @@ impl RawAudioModel {
102105
manifest_dir: &Path,
103106
) -> Result<Self> {
104107
// 1. Extract sample_rate + window_samples from the RawAudio variant.
105-
let (sample_rate, segment_samples) = match &manifest.preprocess_method {
108+
let (sample_rate, segment_samples, pass_orig_sample_rate) = match &manifest.preprocess_method {
106109
PreprocessMethod::RawAudio {
107110
sample_rate,
108111
window_samples,
109-
} => (*sample_rate, *window_samples as usize),
112+
pass_orig_sample_rate,
113+
} => (*sample_rate, *window_samples as usize, *pass_orig_sample_rate),
110114
other => {
111115
return Err(SparrowEngineError::NotAnAudioModel {
112116
id: manifest.id.clone(),
@@ -210,9 +214,25 @@ impl RawAudioModel {
210214
let probe_input = ndarray::Array2::<f32>::zeros((1, segment_samples));
211215
let probe_value = TensorRef::from_array_view(&probe_input)
212216
.map_err(|e| SparrowEngineError::Ort(format!("probe TensorRef: {e}")))?;
213-
let probe_outputs = probe_session
214-
.run(ort::inputs![probe_value])
215-
.map_err(|e| SparrowEngineError::Ort(format!("probe Session::run: {e}")))?;
217+
// RP-27 Part 2: 2-input ONNX needs orig_sample_rate populated even
218+
// at probe time. Use sample_rate (the no-op case for fill_highfreq).
219+
let probe_sr_arr;
220+
let probe_outputs = if pass_orig_sample_rate {
221+
probe_sr_arr = ndarray::Array1::from_vec(vec![sample_rate as i64]);
222+
let probe_sr_value = TensorRef::from_array_view(&probe_sr_arr)
223+
.map_err(|e| SparrowEngineError::Ort(format!("probe orig_sr TensorRef: {e}")))?;
224+
let inputs: Vec<(std::borrow::Cow<'_, str>, ort::session::SessionInputValue<'_>)> = vec![
225+
(std::borrow::Cow::Borrowed("audio"), probe_value.into()),
226+
(std::borrow::Cow::Borrowed("orig_sample_rate"), probe_sr_value.into()),
227+
];
228+
probe_session
229+
.run(inputs)
230+
.map_err(|e| SparrowEngineError::Ort(format!("probe Session::run (2-input): {e}")))?
231+
} else {
232+
probe_session
233+
.run(ort::inputs![probe_value])
234+
.map_err(|e| SparrowEngineError::Ort(format!("probe Session::run: {e}")))?
235+
};
216236
if probe_outputs.len() <= logits_output_idx {
217237
return Err(SparrowEngineError::Ort(format!(
218238
"raw-audio probe: model has {} outputs, expected logits at index {}",
@@ -281,6 +301,7 @@ impl RawAudioModel {
281301
segment_samples,
282302
stride_samples,
283303
threshold,
304+
pass_orig_sample_rate,
284305
})
285306
}
286307

@@ -395,9 +416,27 @@ impl RawAudioModel {
395416
.session
396417
.lock()
397418
.map_err(|_| SparrowEngineError::Ort("raw audio session lock poisoned".into()))?;
398-
let outputs = guard
399-
.run(ort::inputs![input_value])
400-
.map_err(|e| SparrowEngineError::Ort(format!("Session::run: {e}")))?;
419+
// RP-27 Part 2: when manifest opts in, pass orig_sample_rate as a
420+
// second ONNX input (same as CPU path).
421+
let orig_sr_arr;
422+
let outputs = if self.pass_orig_sample_rate {
423+
orig_sr_arr = ndarray::Array1::from_vec(vec![
424+
audio_samples.orig_sample_rate as i64,
425+
]);
426+
let orig_sr_value = TensorRef::from_array_view(&orig_sr_arr)
427+
.map_err(|e| SparrowEngineError::Ort(format!("orig_sr TensorRef: {e}")))?;
428+
let inputs: Vec<(std::borrow::Cow<'_, str>, ort::session::SessionInputValue<'_>)> = vec![
429+
(std::borrow::Cow::Borrowed("audio"), input_value.into()),
430+
(std::borrow::Cow::Borrowed("orig_sample_rate"), orig_sr_value.into()),
431+
];
432+
guard
433+
.run(inputs)
434+
.map_err(|e| SparrowEngineError::Ort(format!("Session::run (2-input): {e}")))?
435+
} else {
436+
guard
437+
.run(ort::inputs![input_value])
438+
.map_err(|e| SparrowEngineError::Ort(format!("Session::run: {e}")))?
439+
};
401440
if outputs.len() <= self.logits_output_idx {
402441
return Err(SparrowEngineError::Ort(format!(
403442
"raw-audio classifier returned {} outputs; expected at least {}",

0 commit comments

Comments
 (0)