Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,12 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})

file(GLOB _TARGET_LIB_SRC "shared/lib/*.cc")

# When audio support is enabled, add every file matching shared/api/nemo_mel_*
# (the NeMo mel-spectrogram sources) to the library source list.
if(OCOS_ENABLE_AUDIO)
file(GLOB nemo_mel_SRC "shared/api/nemo_mel_*")
list(APPEND _TARGET_LIB_SRC ${nemo_mel_SRC})
endif()

if(OCOS_ENABLE_C_API)
file(GLOB utils_TARGET_SRC "shared/api/c_api_utils.*" "shared/api/runner.hpp")
list(APPEND _TARGET_LIB_SRC ${utils_TARGET_SRC})
Expand Down Expand Up @@ -874,6 +880,7 @@ endif()
# Public usage requirements of ortcustomops: compile definitions plus the include
# directories inherited from the two operator targets.
target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
# Also expose shared/api publicly so headers there (e.g. nemo_mel_spectrogram.h)
# can be included by name.
target_include_directories(ortcustomops PUBLIC "${PROJECT_SOURCE_DIR}/shared/api")

target_link_libraries(ortcustomops PUBLIC ocos_operators)

Expand Down
12 changes: 12 additions & 0 deletions operators/math/dlib/stft_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,17 @@ struct StftNormal {
window[n] = static_cast<float>(n_sin * n_sin);
}

return window;
}

// Symmetric Hann window: matches torch.hann_window(N, periodic=False).
// Uses the classic cosine formula with denominator (N-1):
//   w[n] = 0.5 * (1 - cos(2*pi*n / (N-1))),  n = 0 .. N-1
//
// Degenerate lengths are handled explicitly:
//   N <= 0 -> empty window (a negative N would otherwise be converted to a huge
//             vector size);
//   N == 1 -> {1.0f}; the formula would evaluate 0/0 and produce NaN, whereas
//             torch.hann_window documents a single 1 for window_length == 1.
static std::vector<float> hann_window_symmetric(int N) {
  if (N <= 0) return {};
  if (N == 1) return {1.0f};

  // Local constant instead of M_PI, which is not defined on every toolchain.
  constexpr float kPi = 3.14159265358979323846f;
  const float two_pi = 2.0f * kPi;

  std::vector<float> window(N);
  for (int n = 0; n < N; ++n) {
    window[n] = 0.5f * (1.0f - std::cos(two_pi * n / (N - 1)));
  }
  return window;
}
240 changes: 240 additions & 0 deletions shared/api/nemo_mel_spectrogram.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// NeMo-compatible log-mel spectrogram extraction (Slaney scale, matching librosa/NeMo).
// No ONNX Runtime or other framework dependencies — pure C++ with standard library only.

#include "nemo_mel_spectrogram.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#include <dlib/matrix.h>
#include <math/dlib/stft_norm.hpp>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

namespace nemo_mel {

// Slaney mel scale constants: the scale is linear below kMinLogHz and
// logarithmic at and above it.

static constexpr float kMinLogHz = 1000.0f;              // first frequency of the log region (Hz)
static constexpr float kMinLogMel = 15.0f;               // the same boundary in mels: 1000 / (200/3)
static constexpr float kLinScale = 200.0f / 3.0f;        // Hz per mel (linear region)
static constexpr float kLogStep = 0.06875177742094912f;  // log(6.4) / 27

// Converts a frequency in Hz to Slaney mels.
float HzToMel(float hz) {
  const bool in_log_region = !(hz < kMinLogHz);
  return in_log_region ? kMinLogMel + std::log(hz / kMinLogHz) / kLogStep
                       : hz / kLinScale;
}

// Converts a Slaney mel value back to a frequency in Hz (inverse of HzToMel).
float MelToHz(float mel) {
  const bool in_log_region = !(mel < kMinLogMel);
  return in_log_region ? kMinLogHz * std::exp((mel - kMinLogMel) * kLogStep)
                       : mel * kLinScale;
}

// Builds a triangular mel filterbank with Slaney (area) normalization.
//   num_mels    : number of filters (rows of the result).
//   fft_size    : DFT size; each filter has fft_size/2 + 1 weights.
//   sample_rate : sampling rate in Hz; the filters span 0 .. sample_rate/2.
// Returns filterbank[m][k], the weight of FFT bin k in mel filter m.
std::vector<std::vector<float>> CreateMelFilterbank(int num_mels, int fft_size, int sample_rate) {
  const int bin_count = fft_size / 2 + 1;
  const float mel_min = HzToMel(0.0f);
  const float mel_max = HzToMel(static_cast<float>(sample_rate) / 2.0f);
  const float mel_span = mel_max - mel_min;

  // num_mels + 2 band edges, evenly spaced on the mel axis, expressed in Hz.
  std::vector<float> edge_hz(num_mels + 2);
  for (int i = 0; i < num_mels + 2; ++i) {
    const float mel = mel_min + mel_span * i / (num_mels + 1);
    edge_hz[i] = MelToHz(mel);
  }

  // Width in Hz of each interval between neighbouring edges.
  std::vector<float> edge_gap(num_mels + 1);
  for (int i = 0; i < num_mels + 1; ++i) {
    edge_gap[i] = edge_hz[i + 1] - edge_hz[i];
  }

  // Frequency in Hz of every FFT bin.
  std::vector<float> bin_hz(bin_count);
  for (int k = 0; k < bin_count; ++k) {
    bin_hz[k] = static_cast<float>(k) * sample_rate / fft_size;
  }

  std::vector<std::vector<float>> bank(num_mels, std::vector<float>(bin_count, 0.0f));
  for (int m = 0; m < num_mels; ++m) {
    const float left_hz = edge_hz[m];
    const float right_hz = edge_hz[m + 2];
    // The tiny epsilon keeps the divisions finite if two edges coincide.
    const float rise_width = edge_gap[m] + 1e-10f;
    const float fall_width = edge_gap[m + 1] + 1e-10f;
    // Slaney area normalization: 2 / bandwidth.
    const float area_norm = 2.0f / (right_hz - left_hz + 1e-10f);

    std::vector<float>& row = bank[m];
    for (int k = 0; k < bin_count; ++k) {
      const float rising = (bin_hz[k] - left_hz) / rise_width;
      const float falling = (right_hz - bin_hz[k]) / fall_width;
      // Triangle = min of the two ramps, clipped at zero, then scaled.
      row[k] = std::max(0.0f, std::min(rising, falling)) * area_norm;
    }
  }
  return bank;
}

// Computes the power spectrum |X[k]|^2 of a single frame.
//   frame      : pointer to at least frame_len raw (un-windowed) samples.
//   window     : pointer to at least frame_len window coefficients.
//   frame_len  : number of samples multiplied by the window; the product is
//                zero-padded up to fft_size before the transform.
//   fft_size   : transform length.
//   magnitudes : output, resized to fft_size/2 + 1 power values.
// If frame_len exceeds fft_size only the first fft_size samples are used, so the
// FFT input buffer is never written out of bounds.
void ComputeSTFTFrame(const float* frame, const float* window, int frame_len,
                      int fft_size, std::vector<float>& magnitudes) {
  const int num_bins = fft_size / 2 + 1;
  magnitudes.resize(num_bins);

  // Clamp the copy length to the buffer size (and to zero for negative input).
  const int copy_len = std::max(0, std::min(frame_len, fft_size));

  // Apply window and zero-pad to fft_size for FFT
  dlib::matrix<float, 1, 0> windowed(1, fft_size);
  windowed = 0;
  for (int n = 0; n < copy_len; ++n) {
    windowed(0, n) = frame[n] * window[n];
  }

  // Real-valued FFT via dlib (O(N log N) instead of naive O(N²) DFT)
  dlib::matrix<std::complex<float>> fft_result = dlib::fftr(windowed);

  // Power spectrum: |X[k]|²
  for (int k = 0; k < num_bins; ++k) {
    const float re = fft_result(0, k).real();
    const float im = fft_result(0, k).imag();
    magnitudes[k] = re * re + im * im;
  }
}

// BATCH LOG-MEL EXTRACTION
//
// Computes the log-mel spectrogram of a complete audio buffer.
//   audio, num_samples : mono float samples (audio is only dereferenced when
//                        num_samples > 0).
//   cfg                : feature configuration; it is honoured on every call.
//   out_num_frames     : receives the number of time frames produced.
// Returns num_mels * num_frames values, row-major [num_mels, num_frames], each
// being log(mel_energy + cfg.log_eps).
// A configuration that cannot be framed (num_mels, hop_length or win_length <= 0,
// or win_length > fft_size) yields an empty result with out_num_frames == 0.
std::vector<float> NemoComputeLogMelBatch(const float* audio, size_t num_samples,
                                          const NemoMelConfig& cfg, int& out_num_frames) {
  out_num_frames = 0;
  if (cfg.num_mels <= 0 || cfg.hop_length <= 0 || cfg.win_length <= 0 ||
      cfg.fft_size < cfg.win_length) {
    return {};
  }

  // Built from `cfg` on every call. Caching these in function-local statics would
  // pin the filterbank and window of the very first call, so a later call with a
  // different config would silently use the wrong tables (and index past them).
  const auto mel_filters = CreateMelFilterbank(cfg.num_mels, cfg.fft_size, cfg.sample_rate);
  // NOTE(review): this path uses hann_window while the streaming extractor uses
  // hann_window_symmetric; confirm the two are meant to differ.
  const auto window = hann_window(cfg.win_length);

  // Center-pad both sides with fft_size/2 zeros (matching torch.stft center=True)
  // and write the pre-emphasized signal y[n] = x[n] - preemph * x[n-1] straight
  // into the padded buffer. The first sample has no predecessor and is copied.
  const size_t pad = static_cast<size_t>(cfg.fft_size / 2);
  std::vector<float> padded(pad + num_samples + pad, 0.0f);
  if (num_samples > 0) {
    float* dst = padded.data() + pad;
    dst[0] = audio[0];
    for (size_t i = 1; i < num_samples; ++i) {
      dst[i] = audio[i] - cfg.preemph * audio[i - 1];
    }
  }

  // Guarantee room for at least one full FFT frame.
  const size_t fft_size = static_cast<size_t>(cfg.fft_size);
  if (padded.size() < fft_size) {
    padded.resize(fft_size, 0.0f);
  }

  // Frame count using fft_size as frame size (matching torch.stft)
  const size_t hop = static_cast<size_t>(cfg.hop_length);
  const int num_frames = static_cast<int>((padded.size() - fft_size) / hop) + 1;
  out_num_frames = num_frames;

  // The window is shorter than the FFT frame and sits centered inside it.
  const size_t win_offset = static_cast<size_t>((cfg.fft_size - cfg.win_length) / 2);
  const int num_bins = cfg.fft_size / 2 + 1;
  std::vector<float> magnitudes;  // reused across frames
  std::vector<float> mel_spec(static_cast<size_t>(cfg.num_mels) * num_frames);

  for (int t = 0; t < num_frames; ++t) {
    const float* frame = padded.data() + static_cast<size_t>(t) * hop + win_offset;
    ComputeSTFTFrame(frame, window.data(), cfg.win_length, cfg.fft_size, magnitudes);

    for (int m = 0; m < cfg.num_mels; ++m) {
      const std::vector<float>& filter = mel_filters[m];
      float val = 0.0f;
      for (int k = 0; k < num_bins; ++k) {
        val += filter[k] * magnitudes[k];
      }
      mel_spec[static_cast<size_t>(m) * num_frames + t] = std::log(val + cfg.log_eps);
    }
  }

  return mel_spec;
}

// STREAMING LOG-MEL EXTRACTION
//
// Stores the configuration, precomputes the mel filterbank and the symmetric Hann
// window, and starts with an all-zero overlap buffer of fft_size/2 samples and no
// pre-emphasis history. Initializers are listed in member declaration order.
NemoStreamingMelExtractor::NemoStreamingMelExtractor(const NemoMelConfig& cfg)
    : cfg_(cfg),
      mel_filters_(CreateMelFilterbank(cfg.num_mels, cfg.fft_size, cfg.sample_rate)),
      hann_window_(hann_window_symmetric(cfg.win_length)),
      audio_overlap_(cfg.fft_size / 2, 0.0f),
      preemph_last_sample_(0.0f) {}

// Clears the streaming state: the overlap buffer becomes fft_size/2 zeros again
// and the pre-emphasis history is forgotten. The filterbank and window are kept.
void NemoStreamingMelExtractor::Reset() {
  const int overlap_len = cfg_.fft_size / 2;
  audio_overlap_.assign(overlap_len, 0.0f);
  preemph_last_sample_ = 0.0f;
}

// Processes one chunk of mono float PCM and returns (mel_data, num_frames), where
// mel_data is row-major [num_mels, num_frames] and holds log(mel_energy + log_eps).
//
// The chunk is pre-emphasized (continuing from the last sample of the previous
// chunk), prefixed with the fft_size/2 pre-emphasized samples kept from earlier
// calls, and zero-extended on the right so the last window fits. At least one
// frame is always produced, because the buffer is extended to hold one window.
//
// A configuration that cannot be framed (num_mels, hop_length or win_length <= 0,
// or win_length > fft_size) returns an empty result with 0 frames and leaves the
// streaming state untouched.
std::pair<std::vector<float>, int> NemoStreamingMelExtractor::Process(
    const float* audio, size_t num_samples) {
  if (cfg_.num_mels <= 0 || cfg_.hop_length <= 0 || cfg_.win_length <= 0 ||
      cfg_.fft_size < cfg_.win_length) {
    return {std::vector<float>(), 0};
  }

  // Left-only center pad for streaming: prepend overlap from previous chunk.
  // For the first chunk this is zeros (matching center=True left edge).
  const size_t pad = static_cast<size_t>(cfg_.fft_size / 2);
  std::vector<float> padded(pad + num_samples);
  std::copy(audio_overlap_.begin(), audio_overlap_.end(), padded.begin());

  // Pre-emphasis y[n] = x[n] - preemph * x[n-1], written directly after the
  // overlap. `audio` is only dereferenced when the chunk is non-empty.
  if (num_samples > 0) {
    float* chunk = padded.data() + pad;
    chunk[0] = audio[0] - cfg_.preemph * preemph_last_sample_;
    for (size_t i = 1; i < num_samples; ++i) {
      chunk[i] = audio[i] - cfg_.preemph * audio[i - 1];
    }
    preemph_last_sample_ = audio[num_samples - 1];
  }

  // The next overlap is the last `pad` samples of (old overlap + new chunk);
  // this covers both chunks longer and chunks shorter than the overlap.
  audio_overlap_.assign(padded.end() - static_cast<std::ptrdiff_t>(pad), padded.end());

  // Window centering offset (symmetric window smaller than fft_size)
  const size_t win_length = static_cast<size_t>(cfg_.win_length);
  const size_t win_offset = static_cast<size_t>((cfg_.fft_size - cfg_.win_length) / 2);

  // Right-pad with zeros for the window offset of the last frame, and never
  // leave the buffer shorter than one complete window.
  padded.resize(std::max(padded.size() + win_offset, win_offset + win_length), 0.0f);

  // Frame count
  const size_t hop = static_cast<size_t>(cfg_.hop_length);
  const int num_frames = static_cast<int>((padded.size() - win_offset - win_length) / hop) + 1;

  const int num_bins = cfg_.fft_size / 2 + 1;
  std::vector<float> mel_spec(static_cast<size_t>(cfg_.num_mels) * num_frames);
  std::vector<float> magnitudes;  // reused across frames instead of reallocating

  for (int t = 0; t < num_frames; ++t) {
    const float* frame = padded.data() + static_cast<size_t>(t) * hop + win_offset;

    // FFT with symmetric Hann window (win_length samples, zero-padded to fft_size)
    ComputeSTFTFrame(frame, hann_window_.data(), cfg_.win_length, cfg_.fft_size, magnitudes);

    // Apply mel filterbank + log
    for (int m = 0; m < cfg_.num_mels; ++m) {
      const std::vector<float>& filter = mel_filters_[m];
      float val = 0.0f;
      for (int k = 0; k < num_bins; ++k) {
        val += filter[k] * magnitudes[k];
      }
      mel_spec[static_cast<size_t>(m) * num_frames + t] = std::log(val + cfg_.log_eps);
    }
  }

  // Move the buffer into the pair; naming it in the braces would copy it.
  return {std::move(mel_spec), num_frames};
}

} // namespace nemo_mel
86 changes: 86 additions & 0 deletions shared/api/nemo_mel_spectrogram.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// NeMo-compatible log-mel spectrogram extraction (Slaney scale, matching librosa/NeMo).

#pragma once

#include <cstddef>
#include <utility>
#include <vector>

namespace nemo_mel {

/// Parameters of the log-mel feature extraction.
/// There are no default member initializers: callers must set every field.
struct NemoMelConfig {
// Number of mel filters, i.e. rows of the output spectrogram.
int num_mels;
// DFT size; each frame yields fft_size/2 + 1 frequency bins and the signal is
// padded with fft_size/2 samples for centering.
int fft_size;
// Number of samples between the starts of consecutive frames.
int hop_length;
// Window length in samples; the window is centered inside the fft_size frame
// using an offset of (fft_size - win_length) / 2.
int win_length;
// Sampling rate in Hz; the mel filterbank spans 0 .. sample_rate/2.
int sample_rate;
// Pre-emphasis coefficient: y[n] = x[n] - preemph * x[n-1].
float preemph;
// Constant added to each mel energy before taking the logarithm.
float log_eps;
};

// Mel scale conversions (Slaney)

float HzToMel(float hz);
float MelToHz(float mel);

/// Build a triangular mel filterbank with Slaney normalization (matches librosa).
/// Returns shape [num_mels][num_bins] where num_bins = fft_size/2 + 1.
std::vector<std::vector<float>> CreateMelFilterbank(int num_mels, int fft_size, int sample_rate);

/// Compute |DFT|^2 (power spectrum) for a single frame; the window is applied internally.
/// frame: pointer to at least frame_len raw (un-windowed) samples.
/// window: pointer to at least frame_len window coefficients.
/// frame_len: number of samples read from frame and window (must not exceed fft_size);
///            the windowed samples are zero-padded to fft_size.
/// fft_size: DFT size (output has fft_size/2 + 1 bins).
/// magnitudes: output power spectrum (resized to fft_size/2 + 1).
void ComputeSTFTFrame(const float* frame, const float* window, int frame_len,
int fft_size, std::vector<float>& magnitudes);


// BATCH LOG-MEL EXTRACTION
/// Compute NeMo-compatible log-mel spectrogram for a complete audio buffer.
/// Applies pre-emphasis, center-pads both sides (fft_size/2 zeros), computes STFT
/// with a periodic Hann window, applies mel filterbank, and takes log(mel + eps).
///
/// Output layout: row-major [num_mels, num_frames].
/// out_num_frames is set to the number of time frames produced.
std::vector<float> NemoComputeLogMelBatch(const float* audio, size_t num_samples,
const NemoMelConfig& cfg, int& out_num_frames);

// STREAMING LOG-MEL EXTRACTION
/// Stateful streaming NeMo-compatible mel extractor that maintains overlap and
/// pre-emphasis state across successive audio chunks.
///
/// Usage:
/// nemo_mel::NemoStreamingMelExtractor extractor(cfg);
/// auto [mel, frames] = extractor.Process(chunk1, n1);
/// auto [mel2, frames2] = extractor.Process(chunk2, n2);
/// extractor.Reset(); // new utterance
///
class NemoStreamingMelExtractor {
public:
/// Copies cfg and precomputes the mel filterbank and the symmetric Hann window;
/// the overlap buffer starts as fft_size/2 zeros.
explicit NemoStreamingMelExtractor(const NemoMelConfig& cfg);

/// Process one chunk of raw PCM audio (mono, float32).
/// Returns (mel_data, num_frames) where mel_data is row-major [num_mels, num_frames].
/// Updates the streaming state (overlap buffer and pre-emphasis history), so
/// successive calls continue from where the previous chunk ended.
std::pair<std::vector<float>, int> Process(const float* audio, size_t num_samples);

/// Reset all streaming state for a new utterance.
/// The configuration, filterbank and window are kept.
void Reset();

/// Read-only access to the configuration given at construction.
const NemoMelConfig& config() const { return cfg_; }

private:
// Immutable after construction.
NemoMelConfig cfg_;
std::vector<std::vector<float>> mel_filters_;  // [num_mels][fft_size/2 + 1]
std::vector<float> hann_window_; // symmetric, length = win_length

// Streaming state
std::vector<float> audio_overlap_; // last fft_size/2 pre-emphasized samples
float preemph_last_sample_{0.0f};  // last raw input sample of the previous chunk
};

} // namespace nemo_mel
Loading
Loading