diff --git a/CMakeLists.txt b/CMakeLists.txt
index 330c49933..b82bbfcaf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -777,6 +777,12 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
 target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})
 
 file(GLOB _TARGET_LIB_SRC "shared/lib/*.cc")
+
+if(OCOS_ENABLE_AUDIO)
+  file(GLOB nemo_mel_SRC "shared/api/nemo_mel_*")
+  list(APPEND _TARGET_LIB_SRC ${nemo_mel_SRC})
+endif()
+
 if(OCOS_ENABLE_C_API)
   file(GLOB utils_TARGET_SRC "shared/api/c_api_utils.*" "shared/api/runner.hpp")
   list(APPEND _TARGET_LIB_SRC ${utils_TARGET_SRC})
@@ -874,6 +880,7 @@ endif()
 target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
 target_include_directories(ortcustomops PUBLIC "$")
 target_include_directories(ortcustomops PUBLIC "$")
+target_include_directories(ortcustomops PUBLIC "${PROJECT_SOURCE_DIR}/shared/api")
 target_link_libraries(ortcustomops PUBLIC ocos_operators)
 
 
diff --git a/operators/math/dlib/stft_norm.hpp b/operators/math/dlib/stft_norm.hpp
index f9769964b..1a76e4375 100644
--- a/operators/math/dlib/stft_norm.hpp
+++ b/operators/math/dlib/stft_norm.hpp
@@ -64,5 +64,21 @@ struct StftNormal {
     window[n] = static_cast<float>(n_sin * n_sin);
   }
 
+  return window;
+}
+
+// Symmetric Hann window: matches torch.hann_window(N, periodic=False).
+// Uses the classic cosine formula with denominator (N-1).
+// N <= 0 yields an empty window and N == 1 yields {1.0f} (as torch does); without
+// these guards N == 1 evaluates 0/0 and the single coefficient becomes NaN.
+static std::vector<float> hann_window_symmetric(int N) {
+  if (N <= 0) return {};
+  if (N == 1) return {1.0f};
+  std::vector<float> window(N);
+
+  for (int n = 0; n < N; ++n) {
+    window[n] = 0.5f * (1.0f - std::cos(2.0f * static_cast<float>(M_PI) * n / (N - 1)));
+  }
+
   return window;
 }
\ No newline at end of file
diff --git a/shared/api/nemo_mel_spectrogram.cc b/shared/api/nemo_mel_spectrogram.cc
new file mode 100644
index 000000000..18e33fcb0
--- /dev/null
+++ b/shared/api/nemo_mel_spectrogram.cc
@@ -0,0 +1,240 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// NeMo-compatible log-mel spectrogram extraction (Slaney scale, matching librosa/NeMo).
+// No ONNX Runtime dependency: standard C++ plus dlib, which is used only for the real FFT.
+
+#include "nemo_mel_spectrogram.h"
+
+#include <algorithm>
+#include <cmath>
+#include <complex>
+#include <vector>
+
+#include <dlib/matrix.h>  // dlib::matrix, dlib::fftr
+#include <math/dlib/stft_norm.hpp>  // hann_window(), hann_window_symmetric()
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+namespace nemo_mel {
+
+// Slaney mel scale constants
+
+static constexpr float kMinLogHz = 1000.0f;
+static constexpr float kMinLogMel = 15.0f;  // 1000 / (200/3)
+static constexpr float kLinScale = 200.0f / 3.0f;  // Hz per mel (linear region)
+static constexpr float kLogStep = 0.06875177742094912f;  // log(6.4) / 27
+
+float HzToMel(float hz) {
+  if (hz < kMinLogHz) return hz / kLinScale;
+  return kMinLogMel + std::log(hz / kMinLogHz) / kLogStep;
+}
+
+float MelToHz(float mel) {
+  if (mel < kMinLogMel) return mel * kLinScale;
+  return kMinLogHz * std::exp((mel - kMinLogMel) * kLogStep);
+}
+
+std::vector<std::vector<float>> CreateMelFilterbank(int num_mels, int fft_size, int sample_rate) {
+  const int num_bins = fft_size / 2 + 1;
+  const float mel_low = HzToMel(0.0f);
+  const float mel_high = HzToMel(static_cast<float>(sample_rate) / 2.0f);
+
+  // Compute mel center frequencies in Hz (num_mels + 2 points)
+  std::vector<float> mel_f(num_mels + 2);
+  for (int i = 0; i < num_mels + 2; ++i) {
+    float m = mel_low + (mel_high - mel_low) * i / (num_mels + 1);
+    mel_f[i] = MelToHz(m);
+  }
+
+  // Differences between consecutive mel center frequencies (Hz)
+  std::vector<float> fdiff(num_mels + 1);
+  for (int i = 0; i < num_mels + 1; ++i) {
+    fdiff[i] = mel_f[i + 1] - mel_f[i];
+  }
+
+  // FFT bin center frequencies in Hz
+  std::vector<float> fft_freqs(num_bins);
+  for (int k = 0; k < num_bins; ++k) {
+    fft_freqs[k] = static_cast<float>(k) * sample_rate / fft_size;
+  }
+
+  // Triangular filterbank, Slaney-normalized (librosa's form plus a 1e-10 zero-width guard)
+  std::vector<std::vector<float>> filterbank(num_mels, std::vector<float>(num_bins, 0.0f));
+  for (int m = 0; m < num_mels; ++m) {
+    for (int k = 0; k < num_bins; ++k) {
+      float lower = (fft_freqs[k] - mel_f[m]) / (fdiff[m] + 1e-10f);
+      float upper = (mel_f[m + 2] - fft_freqs[k]) / (fdiff[m + 1] + 1e-10f);
+      filterbank[m][k] = std::max(0.0f, std::min(lower, upper));
+    }
+    // Slaney area normalization: 2 / bandwidth
+    float enorm = 2.0f / (mel_f[m + 2] - mel_f[m] + 1e-10f);
+    for (int k = 0; k < num_bins; ++k) {
+      filterbank[m][k] *= enorm;
+    }
+  }
+  return filterbank;
+}
+
+void ComputeSTFTFrame(const float* frame, const float* window, int frame_len,
+                      int fft_size, std::vector<float>& magnitudes) {
+  const int num_bins = fft_size / 2 + 1;
+  magnitudes.resize(num_bins);
+
+  // Apply the window and zero-pad to fft_size. frame_len is clamped to fft_size so an
+  // oversized frame cannot write past the end of the FFT buffer.
+  const int copy_len = std::min(frame_len, fft_size);
+  dlib::matrix<float> windowed(1, fft_size);
+  windowed = 0;
+  for (int n = 0; n < copy_len; ++n) {
+    windowed(0, n) = frame[n] * window[n];
+  }
+
+  // Real-valued FFT via dlib (O(N log N) instead of naive O(N²) DFT)
+  dlib::matrix<std::complex<float>> fft_result = dlib::fftr(windowed);
+  // Power spectrum: |X[k]|²
+  for (int k = 0; k < num_bins; ++k) {
+    const std::complex<float> bin = fft_result(0, k);
+    magnitudes[k] = bin.real() * bin.real() + bin.imag() * bin.imag();
+  }
+}
+
+// BATCH LOG-MEL EXTRACTION
+std::vector<float> NemoComputeLogMelBatch(const float* audio, size_t num_samples,
+                                          const NemoMelConfig& cfg, int& out_num_frames) {
+  // Reject configs that would divide by zero (hop_length), size a buffer negatively, or
+  // put the window outside the FFT frame; report zero frames instead of crashing.
+  out_num_frames = 0;
+  if (cfg.num_mels <= 0 || cfg.fft_size <= 0 || cfg.hop_length <= 0 || cfg.win_length <= 0 ||
+      cfg.win_length > cfg.fft_size || (audio == nullptr && num_samples > 0)) {
+    return {};
+  }
+
+  // Built for *this* cfg on every call. These used to be function-local statics, which
+  // froze the first caller's config: a later call with a different num_mels, fft_size,
+  // sample_rate or win_length silently reused stale, wrongly sized tables.
+  const auto mel_filters = CreateMelFilterbank(cfg.num_mels, cfg.fft_size, cfg.sample_rate);
+  const auto window = hann_window(cfg.win_length);  // periodic Hann
+
+  const int n = static_cast<int>(num_samples);
+  const int pad = cfg.fft_size / 2;
+
+  // Pre-emphasis y[i] = x[i] - preemph * x[i-1] (the first sample has no predecessor),
+  // written straight between fft_size/2 zeros on each side (torch.stft center=True).
+  std::vector<float> padded(static_cast<size_t>(pad) + n + pad, 0.0f);
+  if (n > 0) padded[pad] = audio[0];
+  for (int i = 1; i < n; ++i) {
+    padded[pad + i] = audio[i] - cfg.preemph * audio[i - 1];
+  }
+  if (static_cast<int>(padded.size()) < cfg.fft_size) {
+    padded.resize(cfg.fft_size, 0.0f);
+  }
+
+  // Frame count using fft_size as frame size (matching torch.stft)
+  const int num_frames = static_cast<int>((padded.size() - cfg.fft_size) / cfg.hop_length) + 1;
+  out_num_frames = num_frames;
+
+  // The win_length window sits centred inside each fft_size frame.
+  const int win_offset = (cfg.fft_size - cfg.win_length) / 2;
+  const int num_bins = cfg.fft_size / 2 + 1;
+  std::vector<float> magnitudes;
+  std::vector<float> mel_spec(static_cast<size_t>(cfg.num_mels) * num_frames);
+
+  for (int t = 0; t < num_frames; ++t) {
+    const float* frame = padded.data() + t * cfg.hop_length + win_offset;
+    ComputeSTFTFrame(frame, window.data(), cfg.win_length, cfg.fft_size, magnitudes);
+    for (int m = 0; m < cfg.num_mels; ++m) {
+      float val = 0.0f;
+      for (int k = 0; k < num_bins; ++k) {
+        val += mel_filters[m][k] * magnitudes[k];
+      }
+      mel_spec[m * num_frames + t] = std::log(val + cfg.log_eps);
+    }
+  }
+  return mel_spec;
+}
+
+// STREAMING LOG-MEL EXTRACTION
+NemoStreamingMelExtractor::NemoStreamingMelExtractor(const NemoMelConfig& cfg)
+    : cfg_(cfg) {
+  mel_filters_ = CreateMelFilterbank(cfg_.num_mels, cfg_.fft_size, cfg_.sample_rate);
+  hann_window_ = hann_window_symmetric(cfg_.win_length);
+  audio_overlap_.assign(cfg_.fft_size / 2, 0.0f);
+  preemph_last_sample_ = 0.0f;
+}
+
+void NemoStreamingMelExtractor::Reset() {
+  audio_overlap_.assign(cfg_.fft_size / 2, 0.0f);
+  preemph_last_sample_ = 0.0f;
+}
+
+// Streams one chunk: returns (log-mel data, frame count), row-major [num_mels, num_frames].
+// Chunks are not buffered. Each call frames [previous fft_size/2 samples | chunk | zeros],
+// so frames near the end of a chunk see zeros instead of the audio that arrives with the
+// next chunk, and every call with a valid config yields at least one frame.
+std::pair<std::vector<float>, int> NemoStreamingMelExtractor::Process(
+    const float* audio, size_t num_samples) {
+  // Invalid configs (zero hop, window wider than the FFT, ...) would divide by zero or
+  // read out of bounds below, so they produce no frames and leave the state untouched.
+  if (cfg_.num_mels <= 0 || cfg_.fft_size <= 0 || cfg_.hop_length <= 0 ||
+      cfg_.win_length <= 0 || cfg_.win_length > cfg_.fft_size ||
+      (audio == nullptr && num_samples > 0)) {
+    return {std::vector<float>{}, 0};
+  }
+
+  // Left-only center pad for streaming: start with the overlap kept from the previous
+  // chunk. For the first chunk this is zeros (matching center=True left edge).
+  const int pad = cfg_.fft_size / 2;
+  std::vector<float> padded(pad + num_samples);
+  std::copy(audio_overlap_.begin(), audio_overlap_.end(), padded.begin());
+
+  // Pre-emphasis y[i] = x[i] - preemph * x[i-1]; x[-1] is the previous chunk's last sample.
+  float prev = preemph_last_sample_;
+  for (size_t i = 0; i < num_samples; ++i) {
+    padded[pad + i] = audio[i] - cfg_.preemph * prev;
+    prev = audio[i];
+  }
+  preemph_last_sample_ = prev;
+
+  // Update overlap buffer for next chunk: the last fft_size/2 pre-emphasized samples,
+  // which still include part of the old overlap when the chunk is shorter than that.
+  audio_overlap_.assign(padded.end() - pad, padded.end());
+
+  // Window centering offset (symmetric window smaller than fft_size)
+  const int win_offset = (cfg_.fft_size - cfg_.win_length) / 2;  // e.g. 56
+
+  // Right-pad to accommodate the window offset for the last frame, and up to one full
+  // window when the chunk is shorter than that.
+  const size_t min_len = static_cast<size_t>(win_offset) + cfg_.win_length;
+  padded.resize(std::max(padded.size() + win_offset, min_len), 0.0f);
+
+  // Frame count
+  const int num_frames = static_cast<int>((padded.size() - min_len) / cfg_.hop_length) + 1;
+
+  const int num_bins = cfg_.fft_size / 2 + 1;
+  std::vector<float> mel_spec(static_cast<size_t>(cfg_.num_mels) * num_frames);
+  std::vector<float> magnitudes;  // reused by every frame
+
+  for (int t = 0; t < num_frames; ++t) {
+    const float* frame = padded.data() + t * cfg_.hop_length + win_offset;
+
+    // FFT with symmetric Hann window (win_length samples, zero-padded to fft_size)
+    ComputeSTFTFrame(frame, hann_window_.data(), cfg_.win_length, cfg_.fft_size, magnitudes);
+
+    // Apply mel filterbank + log
+    for (int m = 0; m < cfg_.num_mels; ++m) {
+      float val = 0.0f;
+      for (int k = 0; k < num_bins; ++k) {
+        val += mel_filters_[m][k] * magnitudes[k];
+      }
+      mel_spec[m * num_frames + t] = std::log(val + cfg_.log_eps);
+    }
+  }
+
+  // Move the buffer out; a braced copy would duplicate the whole spectrogram.
+  return {std::move(mel_spec), num_frames};
+}
+
+}  // namespace nemo_mel
diff --git a/shared/api/nemo_mel_spectrogram.h b/shared/api/nemo_mel_spectrogram.h
new file mode 100644
index 000000000..e09cd6dbb
--- /dev/null
+++ b/shared/api/nemo_mel_spectrogram.h
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// NeMo-compatible log-mel spectrogram extraction (Slaney scale, matching librosa/NeMo).
+
+#pragma once
+
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+namespace nemo_mel {
+
+/// Extraction parameters; zero defaults keep an unset field from being indeterminate.
+struct NemoMelConfig {
+  int num_mels = 0;      // number of mel bands (filterbank rows)
+  int fft_size = 0;      // FFT length in samples; frames are zero-padded up to it
+  int hop_length = 0;    // samples between the starts of consecutive frames
+  int win_length = 0;    // analysis window length in samples (<= fft_size)
+  int sample_rate = 0;   // Hz
+  float preemph = 0.0f;  // pre-emphasis coefficient: y[n] = x[n] - preemph * x[n-1]
+  float log_eps = 0.0f;  // added to each mel energy before taking the log
+};
+
+// Mel scale conversions (Slaney)
+float HzToMel(float hz);
+float MelToHz(float mel);
+
+/// Build a triangular mel filterbank with Slaney normalization (matches librosa).
+/// Returns shape [num_mels][num_bins] where num_bins = fft_size/2 + 1.
+std::vector<std::vector<float>> CreateMelFilterbank(int num_mels, int fft_size, int sample_rate);
+
+/// Compute |DFT|^2 (power spectrum) of one frame; the window is applied internally.
+/// frame:      pointer to at least frame_len raw (not yet windowed) samples.
+/// window:     pointer to at least frame_len window coefficients.
+/// frame_len:  number of samples read from frame and window; the rest is zero-padded.
+/// fft_size:   DFT size (output has fft_size/2 + 1 bins).
+/// magnitudes: output power spectrum (resized to fft_size/2 + 1).
+void ComputeSTFTFrame(const float* frame, const float* window, int frame_len,
+                      int fft_size, std::vector<float>& magnitudes);
+
+
+// BATCH LOG-MEL EXTRACTION
+/// Compute NeMo-compatible log-mel spectrogram for a complete audio buffer.
+/// Applies pre-emphasis, center-pads both sides (fft_size/2 zeros), computes STFT
+/// with a periodic Hann window, applies mel filterbank, and takes log(mel + eps).
+///
+/// Output layout: row-major [num_mels, num_frames].
+/// out_num_frames is set to the number of time frames produced.
+std::vector<float> NemoComputeLogMelBatch(const float* audio, size_t num_samples,
+                                          const NemoMelConfig& cfg, int& out_num_frames);
+
+// STREAMING LOG-MEL EXTRACTION
+/// Stateful streaming NeMo-compatible mel extractor that maintains overlap and
+/// pre-emphasis state across successive audio chunks. It uses a symmetric Hann window
+/// (the batch function uses a periodic one), so the two paths are not bit-identical.
+///
+/// Usage:
+///   nemo_mel::NemoStreamingMelExtractor extractor(cfg);
+///   auto [mel, frames] = extractor.Process(chunk1, n1);
+///   auto [mel2, frames2] = extractor.Process(chunk2, n2);
+///   extractor.Reset(); // new utterance
+class NemoStreamingMelExtractor {
+ public:
+  explicit NemoStreamingMelExtractor(const NemoMelConfig& cfg);
+
+  /// Process one chunk of raw PCM audio (mono, float32). Chunks are not buffered: each
+  /// call zero-pads on the right and, for a valid config, yields at least one frame.
+  /// Returns (mel_data, num_frames) where mel_data is row-major [num_mels, num_frames].
+  std::pair<std::vector<float>, int> Process(const float* audio, size_t num_samples);
+
+  /// Reset all streaming state for a new utterance.
+  void Reset();
+
+  const NemoMelConfig& config() const { return cfg_; }
+
+ private:
+  NemoMelConfig cfg_;
+  std::vector<std::vector<float>> mel_filters_;
+  std::vector<float> hann_window_;  // symmetric, length = win_length
+
+  // Streaming state
+  std::vector<float> audio_overlap_;  // last fft_size/2 pre-emphasized samples
+  float preemph_last_sample_{0.0f};
+};
+
+}  // namespace nemo_mel
diff --git a/test/static_test/test_nemo_mel.cc b/test/static_test/test_nemo_mel.cc
new file mode 100644
index 000000000..d84a41ce1
--- /dev/null
+++ b/test/static_test/test_nemo_mel.cc
@@ -0,0 +1,318 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "nemo_mel_spectrogram.h"
+#include <math/dlib/stft_norm.hpp>  // hann_window(), hann_window_symmetric()
+
+using namespace nemo_mel;
+
+static NemoMelConfig MakeTestConfig() {
+  NemoMelConfig cfg;
+  cfg.num_mels = 128;
+  cfg.fft_size = 512;
+  cfg.hop_length = 160;
+  cfg.win_length = 400;
+  cfg.sample_rate = 16000;
+  cfg.preemph = 0.97f;
+  cfg.log_eps = 5.96046448e-08f;
+  return cfg;
+}
+
+// Generate a pure sine wave (mono, float32).
+static std::vector<float> SineWave(float freq_hz, float duration_sec,
+                                   int sample_rate = 16000, float amplitude = 0.5f) {
+  int n = static_cast<int>(duration_sec * sample_rate);
+  std::vector<float> wav(n);
+  const float two_pi = 2.0f * static_cast<float>(M_PI);
+  for (int i = 0; i < n; ++i) {
+    wav[i] = amplitude * std::sin(two_pi * freq_hz * i / sample_rate);
+  }
+  return wav;
+}
+
+// Mel scale conversions (Slaney)
+
+TEST(NemoMelTest, HzToMelLinearRegion) {
+  // Below 1000 Hz the Slaney scale is linear: mel = 3 * hz / 200
+  EXPECT_FLOAT_EQ(HzToMel(0.0f), 0.0f);
+  EXPECT_FLOAT_EQ(HzToMel(200.0f), 3.0f);
+  EXPECT_FLOAT_EQ(HzToMel(1000.0f), 15.0f);
+}
+
+TEST(NemoMelTest, HzToMelLogRegion) {
+  // Above 1000 Hz the Slaney scale is logarithmic
+  float mel_2000 = HzToMel(2000.0f);
+  float mel_4000 = HzToMel(4000.0f);
+  // mel(4000) - mel(2000) should equal mel(2000) - mel(1000) since the log region
+  // has equal spacing per octave
+  float diff_upper = mel_4000 - mel_2000;
+  float diff_lower = mel_2000 - HzToMel(1000.0f);
+  EXPECT_NEAR(diff_upper, diff_lower, 0.01f);
+}
+
+TEST(NemoMelTest, MelToHzRoundTrip) {
+  // HzToMel and MelToHz should be inverses
+  for (float hz : {0.0f, 100.0f, 500.0f, 1000.0f, 2000.0f, 4000.0f, 8000.0f}) {
+    float mel = HzToMel(hz);
+    float hz_back = MelToHz(mel);
+    EXPECT_NEAR(hz_back, hz, 0.01f) << "Round-trip failed for hz=" << hz;
+  }
+}
+
+// ─── Filterbank ─────────────────────────────────────────────────────────────
+
+TEST(NemoMelTest, FilterbankShape) {
+  auto fb = CreateMelFilterbank(128, 512, 16000);
+  ASSERT_EQ(fb.size(), 128u);
+  ASSERT_EQ(fb[0].size(), 257u);  // fft_size/2 + 1
+}
+
+TEST(NemoMelTest, FilterbankNonNegative) {
+  auto fb = CreateMelFilterbank(128, 512, 16000);
+  for (const auto& row : fb) {
+    for (float v : row) {
+      EXPECT_GE(v, 0.0f);
+    }
+  }
+}
+
+TEST(NemoMelTest, FilterbankTriangular) {
+  // Each mel filter should be triangular: has a single peak with values
+  // rising then falling, with no internal zeros between non-zero values.
+  auto fb = CreateMelFilterbank(64, 512, 16000);
+  for (size_t m = 0; m < fb.size(); ++m) {
+    const auto& row = fb[m];
+    // Find first and last non-zero
+    int first_nz = -1, last_nz = -1;
+    for (int i = 0; i < static_cast<int>(row.size()); ++i) {
+      if (row[i] > 0.0f) {
+        if (first_nz < 0) first_nz = i;
+        last_nz = i;
+      }
+    }
+    if (first_nz < 0) continue;  // empty filter at edges is ok
+    // All values between first_nz and last_nz should be positive
+    for (int i = first_nz; i <= last_nz; ++i) {
+      EXPECT_GT(row[i], 0.0f) << "Zero gap in mel filter " << m << " at bin " << i;
+    }
+  }
+}
+
+// ─── Window functions ───────────────────────────────────────────────────────
+
+TEST(NemoMelTest, HannSymmetric) {
+  auto w = hann_window_symmetric(400);
+  ASSERT_EQ(w.size(), 400u);
+  // Endpoints should be 0 (symmetric Hann)
+  EXPECT_NEAR(w[0], 0.0f, 1e-7f);
+  EXPECT_NEAR(w[399], 0.0f, 1e-7f);
+  // Middle should be ~1.0
+  EXPECT_NEAR(w[200], 1.0f, 0.01f);
+  // Should be symmetric
+  for (int i = 0; i < 200; ++i) {
+    EXPECT_NEAR(w[i], w[399 - i], 1e-6f) << "Asymmetry at index " << i;
+  }
+}
+
+TEST(NemoMelTest, HannPeriodic) {
+  auto w = hann_window(400);
+  ASSERT_EQ(w.size(), 400u);
+  // Periodic window: first element is 0, last is non-zero
+  EXPECT_FLOAT_EQ(w[0], 0.0f);
+  EXPECT_GT(w[399], 0.0f);
+  // Window should have non-zero values
+  float sum = 0.0f;
+  for (int i = 0; i < 400; ++i) {
+    sum += w[i];
+  }
+  EXPECT_GT(sum, 0.0f);
+}
+
+// ─── STFT frame ─────────────────────────────────────────────────────────────
+
+TEST(NemoMelTest, STFTFrameDCSignal) {
+  // Constant signal: all energy should be in bin 0
+  int fft_size = 512;
+  int win_length = 400;
+  std::vector<float> frame(win_length, 1.0f);
+  auto window = hann_window(win_length);
+  std::vector<float> magnitudes;
+  ComputeSTFTFrame(frame.data(), window.data(), win_length, fft_size, magnitudes);
+  ASSERT_EQ(magnitudes.size(), 257u);
+  // DC bin should have the largest magnitude
+  float dc = magnitudes[0];
+  for (size_t i = 1; i < magnitudes.size(); ++i) {
+    EXPECT_LE(magnitudes[i], dc + 1e-4f) << "Non-DC bin " << i << " exceeds DC";
+  }
+}
+
+TEST(NemoMelTest, STFTFrameZeroSignal) {
+  int fft_size = 512;
+  int win_length = 400;
+  std::vector<float> frame(win_length, 0.0f);
+  auto window = hann_window(win_length);
+  std::vector<float> magnitudes;
+  ComputeSTFTFrame(frame.data(), window.data(), win_length, fft_size, magnitudes);
+  for (size_t i = 0; i < magnitudes.size(); ++i) {
+    EXPECT_NEAR(magnitudes[i], 0.0f, 1e-10f);
+  }
+}
+
+// ─── Batch log-mel extraction ───────────────────────────────────────────────
+
+TEST(NemoMelTest, BatchOutputShape) {
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(440.0f, 0.5f);  // 0.5 sec, 8000 samples
+  int num_frames = 0;
+  auto mel = NemoComputeLogMelBatch(wav.data(), wav.size(), cfg, num_frames);
+
+  EXPECT_GT(num_frames, 0);
+  EXPECT_EQ(mel.size(), static_cast<size_t>(cfg.num_mels) * num_frames);
+
+  // Sanity-check frame count is in a reasonable range.
+  // Exact formula depends on center-padding strategy; just verify ballpark.
+  int min_expected = static_cast<int>(wav.size()) / cfg.hop_length - 2;
+  int max_expected = static_cast<int>(wav.size()) / cfg.hop_length + 5;
+  EXPECT_GE(num_frames, min_expected);
+  EXPECT_LE(num_frames, max_expected);
+}
+
+TEST(NemoMelTest, BatchSilenceOutput) {
+  // Silence should produce very low (near log_eps) mel values
+  auto cfg = MakeTestConfig();
+  std::vector<float> silence(16000, 0.0f);  // 1 sec
+  int num_frames = 0;
+  auto mel = NemoComputeLogMelBatch(silence.data(), silence.size(), cfg, num_frames);
+
+  float expected_log_eps = std::log(cfg.log_eps);
+  for (size_t i = 0; i < mel.size(); ++i) {
+    EXPECT_NEAR(mel[i], expected_log_eps, 0.1f)
+        << "Silence mel value at index " << i << " deviates from log(eps)";
+  }
+}
+
+TEST(NemoMelTest, BatchDeterministic) {
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(1000.0f, 0.3f);
+  int nf1 = 0, nf2 = 0;
+  auto mel1 = NemoComputeLogMelBatch(wav.data(), wav.size(), cfg, nf1);
+  auto mel2 = NemoComputeLogMelBatch(wav.data(), wav.size(), cfg, nf2);
+  ASSERT_EQ(nf1, nf2);
+  ASSERT_EQ(mel1.size(), mel2.size());
+  for (size_t i = 0; i < mel1.size(); ++i) {
+    EXPECT_FLOAT_EQ(mel1[i], mel2[i]);
+  }
+}
+
+TEST(NemoMelTest, BatchSineEnergy) {
+  // A 440Hz sine should concentrate energy in lower mel bands
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(440.0f, 0.5f);
+  int num_frames = 0;
+  auto mel = NemoComputeLogMelBatch(wav.data(), wav.size(), cfg, num_frames);
+
+  // Average mel energy across time for each band
+  std::vector<float> band_avg(cfg.num_mels, 0.0f);
+  for (int m = 0; m < cfg.num_mels; ++m) {
+    for (int t = 0; t < num_frames; ++t) {
+      band_avg[m] += mel[m * num_frames + t];
+    }
+    band_avg[m] /= num_frames;
+  }
+  // The mel band containing 440 Hz should have more energy than the highest band
+  // 440 Hz at 16kHz with 128 mels is in the low mel range
+  float max_low = *std::max_element(band_avg.begin(), band_avg.begin() + 30);
+  float avg_high = 0.0f;
+  for (int m = 100; m < 128; ++m) avg_high += band_avg[m];
+  avg_high /= 28.0f;
+  EXPECT_GT(max_low, avg_high);
+}
+
+// ─── Streaming extractor ────────────────────────────────────────────────────
+
+TEST(NemoMelTest, StreamingSingleChunkMatchesBatch) {
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(440.0f, 0.5f);  // 8000 samples
+
+  // Batch reference
+  int batch_frames = 0;
+  auto batch_mel = NemoComputeLogMelBatch(wav.data(), wav.size(), cfg, batch_frames);
+
+  // Streaming: send all audio in one chunk
+  NemoStreamingMelExtractor extractor(cfg);
+  auto [stream_mel, stream_frames] = extractor.Process(wav.data(), wav.size());
+
+  // Streaming uses symmetric Hann + left-only center-pad vs batch uses periodic Hann
+  // + both-side center-pad, so frame counts may differ by a small amount.
+  EXPECT_NEAR(stream_frames, batch_frames, 2);
+  // Both should produce non-empty output
+  EXPECT_GT(stream_frames, 0);
+  EXPECT_GT(stream_mel.size(), 0u);
+}
+
+TEST(NemoMelTest, StreamingMultiChunk) {
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(440.0f, 1.0f);  // 16000 samples
+
+  NemoStreamingMelExtractor extractor(cfg);
+  int total_frames = 0;
+  std::vector<float> all_mel;
+
+  // Feed in 4 chunks of 4000 samples
+  size_t chunk_size = 4000;
+  for (size_t offset = 0; offset < wav.size(); offset += chunk_size) {
+    size_t n = std::min(chunk_size, wav.size() - offset);
+    auto [mel, frames] = extractor.Process(wav.data() + offset, n);
+    all_mel.insert(all_mel.end(), mel.begin(), mel.end());  // per-chunk blocks, not one matrix
+    total_frames += frames;
+  }
+
+  EXPECT_GT(total_frames, 0);
+  EXPECT_EQ(all_mel.size(), static_cast<size_t>(cfg.num_mels) * total_frames);
+}
+
+TEST(NemoMelTest, StreamingReset) {
+  auto cfg = MakeTestConfig();
+  auto wav = SineWave(440.0f, 0.3f);
+
+  NemoStreamingMelExtractor extractor(cfg);
+
+  // First utterance
+  auto [mel1, nf1] = extractor.Process(wav.data(), wav.size());
+
+  // Reset and process same audio
+  extractor.Reset();
+  auto [mel2, nf2] = extractor.Process(wav.data(), wav.size());
+
+  // Should produce identical results after reset
+  ASSERT_EQ(nf1, nf2);
+  ASSERT_EQ(mel1.size(), mel2.size());
+  for (size_t i = 0; i < mel1.size(); ++i) {
+    EXPECT_FLOAT_EQ(mel1[i], mel2[i]) << "Mismatch after reset at index " << i;
+  }
+}
+
+TEST(NemoMelTest, StreamingEmptyChunk) {
+  auto cfg = MakeTestConfig();
+  NemoStreamingMelExtractor extractor(cfg);
+  // The first Process() call may produce frames from the initial center-pad overlap,
+  // even with 0 input samples. Just verify it doesn't crash.
+  auto [mel, frames] = extractor.Process(nullptr, 0);
+  EXPECT_GE(frames, 0);
+  EXPECT_EQ(mel.size(), static_cast<size_t>(cfg.num_mels) * frames);
+}
+
+TEST(NemoMelTest, StreamingSmallChunk) {
+  // Process() does not buffer short chunks: it zero-pads up to one full window instead.
+  auto cfg = MakeTestConfig();
+  NemoStreamingMelExtractor extractor(cfg);
+  std::vector<float> tiny(100, 0.1f);  // 100 samples < hop_length (160)
+  auto [mel, frames] = extractor.Process(tiny.data(), tiny.size());
+  // 256 overlap + 100 samples + 56 right pad = 412 < 56 + 400, so exactly one frame.
+  EXPECT_EQ(frames, 1);
+  EXPECT_EQ(mel.size(), static_cast<size_t>(cfg.num_mels) * frames);
+}