Skip to content

Commit 5ea8641

Browse files
zhmiaoCopilot
andcommitted
feat(rp-27 part 2): opt-in second ONNX input orig_sample_rate for raw_audio
Extends PreprocessMethod::RawAudio with pass_orig_sample_rate: bool (default false). When true, the CPU detect_audio raw-path engine passes a second ONNX input "orig_sample_rate" [1] int64 alongside the audio tensor. Models that opt in can implement in-graph fill_highfreq (or any other data-driven sample-rate behavior) without engine-side mel knowledge. Wiring: - PreparedAudioKind::Raw gained pass_orig_sample_rate field; carried from manifest through prepare_audio_detection into detect_audio_loop_raw. - session.run dispatches via the named-input form Vec<(Cow<str>, SessionInputValue)> when the flag is on. - resolve_classifier_output (the model-load probe at line 292) also needed the second input — probes with orig_sample_rate=target_sample_rate (no-op for fill_highfreq). Manifest schema: [preprocessing] pass_orig_sample_rate=true|false (default false). Existing perch-v2 + future single-input raw_audio models keep working unchanged. Motivation: orca-ecotype-dclde2026-v1 (RP-onboarding-2026-06-01) uses raw_audio + softmax with the same fill_highfreq requirement as Stage 1. Engine-side mel fill (RP-27 Part 1) doesn't reach the in-graph mel path, so Stage 2 needed its own data-driven fix. The exported ONNX wrapper implements fill_highfreq via sort + dynamic-k gather; the engine just ships orig_sample_rate so the in-graph mask boundary tracks it. Stage 2 parity (300 windows, 10 fixtures): top1 prob delta mean 0.027, median 0.018, max 0.181; argmax flip rate 4.0% (gate 15%). 24 kHz fixture (no-resample, fill no-op) matches engine bit-exactly because the in-graph fill is guarded with no_upsample = (orig_sample_rate >= target_sample_rate) matching upstream PW's `if orig_sr < SR` guard. Lib tests: types 123 + core 178 + cpu 74 = 375 PASS. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5a1ce64 commit 5ea8641

4 files changed

Lines changed: 74 additions & 10 deletions

File tree

sparrow-engine/sparrow-engine-cpu/src/detect_audio.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,12 @@ enum PreparedAudioKind {
133133
/// tensor named "label" at session-load time; for single-output
134134
/// models this is just `0`.
135135
logits_output_idx: usize,
136-
/// Number of classes (= length of the softmax distribution).
136+
/// Number of softmax classes the model emits.
137137
num_classes: usize,
138+
/// Opt-in (RP-27 Part 2, 2026-06-05): when true, the engine passes a
139+
/// second ONNX input `orig_sample_rate [1] int64` alongside the
140+
/// audio tensor. Used by in-graph fill_highfreq.
141+
pass_orig_sample_rate: bool,
138142
},
139143
}
140144

@@ -216,6 +220,7 @@ fn prepare_audio_detection(
216220
}
217221
PreprocessMethod::RawAudio {
218222
window_samples,
223+
pass_orig_sample_rate,
219224
..
220225
} => {
221226
let segment_samples = *window_samples as usize;
@@ -251,14 +256,17 @@ fn prepare_audio_detection(
251256

252257
// Resolve the logits output: prefer the tensor named "label" (Perch 2),
253258
// fall back to output 0 for single-head softmax classifiers.
259+
// When pass_orig_sample_rate=true, probe with a dummy orig_sr=sample_rate
260+
// (the no-op case for fill_highfreq) so the 2-input ONNX accepts the call.
254261
let (logits_output_idx, num_classes) =
255-
resolve_classifier_output(handle, segment_samples)?;
262+
resolve_classifier_output(handle, segment_samples, *pass_orig_sample_rate, sample_rate)?;
256263

257264
Ok(PreparedAudioDetection {
258265
audio_samples,
259266
kind: PreparedAudioKind::Raw {
260267
logits_output_idx,
261268
num_classes,
269+
pass_orig_sample_rate: *pass_orig_sample_rate,
262270
},
263271
segment_samples,
264272
stride_samples,
@@ -286,6 +294,8 @@ fn prepare_audio_detection(
286294
fn resolve_classifier_output(
287295
handle: &ModelHandle,
288296
window_samples: usize,
297+
pass_orig_sample_rate: bool,
298+
target_sample_rate: u32,
289299
) -> Result<(usize, usize)> {
290300
let session = handle.pin_session()?;
291301
let mut guard = session
@@ -302,9 +312,23 @@ fn resolve_classifier_output(
302312
// Probe with one zero-filled window to learn the class count.
303313
let probe = ndarray::Array2::<f32>::zeros((1, window_samples));
304314
let input_value = TensorRef::from_array_view(&probe).map_err(crate::engine::ort_err)?;
305-
let outputs = guard
306-
.run(ort::inputs![input_value])
307-
.map_err(crate::engine::ort_err)?;
315+
// RP-27 Part 2: 2-input ONNX needs orig_sample_rate populated even at probe
316+
// time. Use target_sample_rate (the no-op case for fill_highfreq).
317+
let probe_sr_arr;
318+
let outputs = if pass_orig_sample_rate {
319+
probe_sr_arr = ndarray::Array1::from_vec(vec![target_sample_rate as i64]);
320+
let orig_sr_value =
321+
TensorRef::from_array_view(&probe_sr_arr).map_err(crate::engine::ort_err)?;
322+
let inputs: Vec<(std::borrow::Cow<'_, str>, ort::session::SessionInputValue<'_>)> = vec![
323+
(std::borrow::Cow::Borrowed("audio"), input_value.into()),
324+
(std::borrow::Cow::Borrowed("orig_sample_rate"), orig_sr_value.into()),
325+
];
326+
guard.run(inputs).map_err(crate::engine::ort_err)?
327+
} else {
328+
guard
329+
.run(ort::inputs![input_value])
330+
.map_err(crate::engine::ort_err)?
331+
};
308332
if outputs.len() <= logits_idx {
309333
return Err(SparrowEngineError::Ort(format!(
310334
"classifier session probe returned {} outputs; expected at least {}",
@@ -551,11 +575,12 @@ fn detect_audio_loop_raw(
551575
start: Instant,
552576
mut on_segment: Option<&mut dyn FnMut(&AudioSegment)>,
553577
) -> Result<AudioDetectResult> {
554-
let (logits_output_idx, num_classes) = match &prep.kind {
578+
let (logits_output_idx, num_classes, pass_orig_sample_rate) = match &prep.kind {
555579
PreparedAudioKind::Raw {
556580
logits_output_idx,
557581
num_classes,
558-
} => (*logits_output_idx, *num_classes),
582+
pass_orig_sample_rate,
583+
} => (*logits_output_idx, *num_classes, *pass_orig_sample_rate),
559584
_ => unreachable!("guarded by detect_audio_loop dispatch"),
560585
};
561586

@@ -608,9 +633,26 @@ fn detect_audio_loop_raw(
608633
let mut guard = session
609634
.lock()
610635
.map_err(|_| SparrowEngineError::Ort("audio session lock poisoned".into()))?;
611-
let outputs = guard
612-
.run(ort::inputs![input_value])
613-
.map_err(crate::engine::ort_err)?;
636+
// RP-27 Part 2: when manifest opts in, pass orig_sample_rate as a
637+
// second [1] int64 input alongside the audio tensor. The exported
638+
// ONNX must declare two inputs in this order: ("audio", "orig_sample_rate").
639+
let orig_sr_arr;
640+
let outputs = if pass_orig_sample_rate {
641+
orig_sr_arr = ndarray::Array1::from_vec(vec![
642+
prep.audio_samples.orig_sample_rate as i64,
643+
]);
644+
let orig_sr_value =
645+
TensorRef::from_array_view(&orig_sr_arr).map_err(crate::engine::ort_err)?;
646+
let inputs: Vec<(std::borrow::Cow<'_, str>, ort::session::SessionInputValue<'_>)> = vec![
647+
(std::borrow::Cow::Borrowed("audio"), input_value.into()),
648+
(std::borrow::Cow::Borrowed("orig_sample_rate"), orig_sr_value.into()),
649+
];
650+
guard.run(inputs).map_err(crate::engine::ort_err)?
651+
} else {
652+
guard
653+
.run(ort::inputs![input_value])
654+
.map_err(crate::engine::ort_err)?
655+
};
614656
if outputs.len() <= logits_output_idx {
615657
return Err(SparrowEngineError::Ort(format!(
616658
"audio classifier returned {} outputs; expected at least {}",

sparrow-engine/sparrow-engine-cpu/src/engine.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,7 @@ mod tests {
12401240
PreprocessMethod::RawAudio {
12411241
sample_rate: 32_000,
12421242
window_samples: 160_000,
1243+
pass_orig_sample_rate: false,
12431244
}
12441245
}
12451246

sparrow-engine/sparrow-engine-types/src/manifest.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ pub enum PreprocessMethod {
5353
RawAudio {
5454
sample_rate: u32,
5555
window_samples: u32,
56+
/// Opt-in: when true, engine passes a second ONNX input
57+
/// `orig_sample_rate [1] int64` carrying the original (pre-resample)
58+
/// sample rate. Used by in-graph fill_highfreq passes that need to
59+
/// know whether the audio was upsampled and where the original
60+
/// Nyquist sat. Default false preserves Perch 2 / single-input
61+
/// RawAudio behavior (RP-27 Part 2, 2026-06-05).
62+
pass_orig_sample_rate: bool,
5663
},
5764
}
5865

@@ -406,6 +413,12 @@ struct RawPreprocessing {
406413
/// Number of samples per inference window (= segment_duration_s × sample_rate).
407414
/// Required for `raw_audio`. For Perch 2: 160000 = 5 s × 32 kHz.
408415
window_samples: Option<u32>,
416+
/// RawAudio-only opt-in (RP-27 Part 2, 2026-06-05): when true, the
417+
/// engine passes a second ONNX input `orig_sample_rate [1] int64`
418+
/// alongside the audio tensor so the model can apply in-graph
419+
/// fill_highfreq.
420+
#[serde(default)]
421+
pass_orig_sample_rate: Option<bool>,
409422
/// Opt-in high-frequency fill for mel_spectrogram preprocess (RP-27).
410423
/// Defaults to `false` (md-audiobirds-v1 behavior). When `true` and the
411424
/// engine resamples upward, mel bins above `orig_sr/2 - 2500 Hz` are
@@ -538,6 +551,10 @@ pub fn load_manifest(path: &Path) -> Result<ModelManifest> {
538551
.preprocessing
539552
.window_samples
540553
.ok_or_else(|| raw_err("window_samples"))?,
554+
pass_orig_sample_rate: raw
555+
.preprocessing
556+
.pass_orig_sample_rate
557+
.unwrap_or(false),
541558
}
542559
}
543560
"mel_spectrogram" => {
@@ -668,6 +685,7 @@ pub fn load_manifest(path: &Path) -> Result<ModelManifest> {
668685
if let PreprocessMethod::RawAudio {
669686
sample_rate,
670687
window_samples,
688+
..
671689
} = &preprocess_method
672690
{
673691
if *sample_rate == 0 {
@@ -825,6 +843,7 @@ pub fn load_manifest(path: &Path) -> Result<ModelManifest> {
825843
PreprocessMethod::RawAudio {
826844
sample_rate,
827845
window_samples,
846+
..
828847
},
829848
InferenceStrategy::SlidingWindow {
830849
segment_duration_s, ..
@@ -2033,6 +2052,7 @@ format = "one_per_line"
20332052
PreprocessMethod::RawAudio {
20342053
sample_rate: 32000,
20352054
window_samples: 160000,
2055+
..
20362056
}
20372057
));
20382058
assert_eq!(

sparrow-engine/sparrow-engine-types/src/model_type.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ mod phase_a_r1_model_type_tests {
7272
PreprocessMethod::RawAudio {
7373
sample_rate: 32000,
7474
window_samples: 160000,
75+
pass_orig_sample_rate: false,
7576
}
7677
}
7778

0 commit comments

Comments
 (0)