Skip to content

Commit a919f8e

Browse files
committed
improve audio result accuracy
1 parent 10ca3fd commit a919f8e

File tree

3 files changed

+121
-17
lines changed

3 files changed

+121
-17
lines changed

shared/api/speech_features.hpp

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,99 @@ class SpeechFeatures {
5151
return stft_norm_.Compute(pcm, n_fft_, hop_length_, {fft_win_.data(), fft_win_.size()}, n_fft_, stft_norm);
5252
}
5353

54+
/// Computes a SpeechLib-style power spectrogram of `pcm` and writes it to `stft_norm`.
///
/// Reference implementation (NumPy) this routine mirrors:
///
///   # Spec 1: SpeechLib cut remaining sample insufficient for a hop
///   n_batch = (wav.shape[0] - win_length) // hop_length + 1
///   y_frames = np.array(
///       [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
///       dtype=np.float32,
///   )
///   # Spec 2: SpeechLib applies preemphasis within each batch
///   y_frames_prev = np.roll(y_frames, 1, axis=1)
///   y_frames_prev[:, 0] = y_frames_prev[:, 1]
///   y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
///   S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
///
/// @param pcm        Input samples. `Shape()[1]` samples are read linearly from `Data()`
///                   (assumes a {1, n_samples} tensor - TODO confirm against callers).
/// @param stft_norm  Output tensor, allocated here with shape {1, n_fft_ / 2 + 1, n_batch + 1}.
///                   Each of the first n_batch columns holds the squared magnitude of one
///                   frame's spectrum; the final column is all zeros.
/// @return A default-constructed OrtxStatus.
OrtxStatus SpeechLibSTFTNorm(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& stft_norm) {
  const float preemphasis = 0.97f;

  const auto pcm_length = pcm.Shape()[1];
  auto n_batch = (pcm_length - frame_length_) / hop_length_ + 1;
  // C++ division truncates toward zero whereas the reference uses floor division.
  // Without this guard an input shorter than one frame would yield n_batch == 1
  // and the framing loop below would read past the end of the PCM buffer.
  if (pcm_length < frame_length_) {
    n_batch = 0;
  }
  const auto pcm_data = pcm.Data();

  const long n_frames = static_cast<long>(n_batch);
  const long n_rfft = static_cast<long>(n_fft_ / 2 + 1);
  // Frames longer than n_fft_ are truncated; shorter ones are zero-padded.
  const long copy_length = (std::min)(frame_length_, n_fft_);

  // One row per frame, keeping only the non-redundant half of the spectrum.
  // The extra trailing row (a column once transposed) stays zero so the output
  // has the same shape that STFTNorm produces.
  dlib::matrix<std::complex<float>> S(n_frames + 1, n_rfft);
  for (long c = 0; c < n_rfft; ++c) {
    S(n_frames, c) = std::complex<float>(0.0f, 0.0f);
  }

  // FFT input buffer, reused for every frame. Only the first copy_length entries
  // are rewritten per frame, so the zero padding behind them is set up once.
  dlib::matrix<std::complex<float>> padded_frame(1, n_fft_);
  for (long j = 0; j < n_fft_; ++j) {
    padded_frame(0, j) = std::complex<float>(0.0f, 0.0f);
  }

  for (long r = 0; r < n_frames; ++r) {
    const float* frame = pcm_data + r * hop_length_;

    for (long c = 0; c < copy_length; ++c) {
      // Pre-emphasis is applied within the frame: the "previous" sample of the
      // first element is the first element itself (y_frames_prev[:, 0] = y_frames_prev[:, 1]).
      const float previous = frame[c == 0 ? 0 : c - 1];
      float value = (frame[c] - preemphasis * previous) * 32768.0f;
      value *= fft_win_[c];
      padded_frame(0, c) = std::complex<float>(value, 0.0f);
    }

    // Full complex FFT of the real-valued frame; only bins [0, n_fft_ / 2] are kept.
    const dlib::matrix<std::complex<float>> fft_result = dlib::fft(padded_frame);
    for (long c = 0; c < n_rfft; ++c) {
      S(r, c) = fft_result(0, c);
    }
  }

  // Spectral power (squared magnitude), transposed to (frequency bins, frames).
  dlib::matrix<float> spec_power = dlib::trans(dlib::norm(S));

  std::vector<int64_t> outdim{1, spec_power.nr(), spec_power.nc()};
  auto out0 = stft_norm.Allocate(outdim);
  std::copy(spec_power.begin(), spec_power.end(), out0);

  return {};
}
146+
54147
static std::vector<float> hann_window(int N) {
55148
std::vector<float> window(N);
56149

@@ -137,10 +230,20 @@ class LogMel {
137230
*/
138231
assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1);
139232
std::vector<int64_t> stft_shape = stft_norm.Shape();
140-
dlib::matrix<float> magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1);
233+
int64_t n_fill_zero_col = stft_shape[1]; // if 8k, fill 4k - 8k hz with zeros
234+
int64_t additional_row = 0;
235+
if (mel_filters_.nc() > stft_shape[1]) {
236+
n_fill_zero_col = mel_filters_.nc() - stft_shape[1] - 1;
237+
additional_row = stft_shape[1] - 1;
238+
}
239+
dlib::matrix<float> magnitudes(stft_shape[1] + additional_row, stft_shape[2] - 1);
141240
for (int i = 0; i < magnitudes.nr(); ++i) {
142-
std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1,
143-
magnitudes.begin() + i * magnitudes.nc());
241+
if (i < n_fill_zero_col) {
242+
std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1,
243+
magnitudes.begin() + i * magnitudes.nc());
244+
} else {
245+
std::fill(magnitudes.begin() + i * magnitudes.nc(), magnitudes.begin() + (i + 1) * magnitudes.nc(), 0.0f);
246+
}
144247
}
145248

146249
dlib::matrix<float> mel_spec = mel_filters_ * magnitudes;
@@ -191,7 +294,8 @@ class LogMel {
191294
}
192295

193296
// Function to compute the Mel filterbank
194-
static dlib::matrix<float> MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0,
297+
static dlib::matrix<float> MelFilterBank(int n_fft, int n_mels,
298+
int sr = 16000, float min_mel = 0,
195299
float max_mel = 45.245640471924965) {
196300
// Initialize the filterbank matrix
197301
dlib::matrix<float> fbank(n_mels, n_fft / 2 + 1);
@@ -312,7 +416,7 @@ class Phi4AudioEmbed {
312416
ortc::Tensor<float> stft_norm(&CppAllocator::Instance());
313417
SpeechFeatures stft_normal;
314418
stft_normal.Init(sr_val == 8000? stft_normal_8k_attrs_: stft_normal_attrs_);
315-
auto status = stft_normal.STFTNorm(pcm, stft_norm);
419+
auto status = stft_normal.SpeechLibSTFTNorm(pcm, stft_norm);
316420
if (!status.IsOk()) {
317421
return status;
318422
}

test/data/models/phi-4/audio_feature_extraction.json

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,26 @@
1818
"type": "Phi4AudioEmbed",
1919
"attrs": {
2020
"audio_compression_rate": 8,
21+
"stft_normal/n_fft": 512,
22+
"stft_normal/frame_length": 400,
23+
"stft_normal/hop_length": 160,
24+
"stft_normal/win_fn": "hamming",
2125
"logmel/chunk_size": 30,
2226
"logmel/hop_length": 160,
2327
"logmel/n_fft": 512,
2428
"logmel/n_mel": 80,
2529
"logmel/feature_first": 0,
2630
"logmel/no_padding": 1,
27-
"stft_normal/n_fft": 512,
28-
"stft_normal/frame_length": 400,
29-
"stft_normal/hop_length": 160,
30-
"stft_normal/win_fn": "hamming",
31+
"stft_normal_8k/n_fft": 256,
32+
"stft_normal_8k/frame_length": 200,
33+
"stft_normal_8k/hop_length": 80,
34+
"stft_normal_8k/win_fn": "hamming",
3135
"logmel_8k/chunk_size": 30,
3236
"logmel_8k/hop_length": 80,
33-
"logmel_8k/n_fft": 256,
37+
"logmel_8k/n_fft": 512,
3438
"logmel_8k/n_mel": 80,
3539
"logmel_8k/feature_first": 0,
36-
"logmel_8k/no_padding": 1,
37-
"stft_normal_8k/n_fft": 256,
38-
"stft_normal_8k/frame_length": 200,
39-
"stft_normal_8k/hop_length": 160,
40-
"stft_normal_8k/win_fn": "hamming"
40+
"logmel_8k/no_padding": 1
4141
}
4242
}
4343
}

test/pp_api_test/test_feature_extraction.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ TEST(ExtractorTest, TestPhi4AudioFeatureExtraction) {
5959
size_t num_dims;
6060
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
6161
ASSERT_EQ(err, kOrtxOK);
62-
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 1346, 80}));
62+
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 1344, 80}));
6363

6464
tensor.reset();
6565
err = OrtxTensorResultGetAt(result.get(), 1, tensor.ToBeAssigned());
@@ -90,7 +90,7 @@ TEST(ExtractorTest, TestPhi4AudioFeatureExtraction8k) {
9090
size_t num_dims;
9191
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
9292
ASSERT_EQ(err, kOrtxOK);
93-
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({1, 1470, 80}));
93+
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({1, 2938, 80}));
9494

9595
tensor.reset();
9696
err = OrtxTensorResultGetAt(result.get(), 1, tensor.ToBeAssigned());

0 commit comments

Comments
 (0)