@@ -51,6 +51,99 @@ class SpeechFeatures {
5151 return stft_norm_.Compute (pcm, n_fft_, hop_length_, {fft_win_.data (), fft_win_.size ()}, n_fft_, stft_norm);
5252 }
5353
54+ OrtxStatus SpeechLibSTFTNorm (const ortc::Tensor<float >& pcm, ortc::Tensor<float >& stft_norm) {
55+ const float preemphasis = 0 .97f ;
56+ // # Spec 1: SpeechLib cut remaining sample insufficient for a hop
57+ // n_batch = (wav.shape[0] - win_length) // hop_length + 1
58+ auto pcm_length = pcm.Shape ()[1 ];
59+ auto n_batch = (pcm_length - frame_length_) / hop_length_ + 1 ;
60+ auto pcm_data = pcm.Data ();
61+ dlib::matrix<float > dm_x = dlib::mat (pcm_data, 1 , pcm_length);
62+
63+ // # Here we don't use stride_tricks since the input array may not satisfy
64+ // # memory layout requirement and we need writeable output
65+ // # Here we only use list of views before copy to desination
66+ // # so it is more efficient than broadcasting
67+ // y_frames = np.array(
68+ // [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
69+ // dtype=np.float32,
70+ // )
71+
72+ // # Spec 2: SpeechLib applies preemphasis within each batch
73+ // y_frames_prev = np.roll(y_frames, 1, axis=1)
74+ // y_frames_prev[:, 0] = y_frames_prev[:, 1]
75+ // y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
76+ // S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
77+
78+ // Step 1: Create y_frames_prev by rolling each row right by 1 and adjusting the first element
79+ dlib::matrix<float > y_frames_prev (n_batch, frame_length_);
80+ for (long r = 0 ; r < n_batch; ++r) {
81+ for (long c = 0 ; c < frame_length_; ++c) {
82+ if (c == 0 ) {
83+ y_frames_prev (r, c) = dm_x (0 , r * hop_length_ + frame_length_ - 1 );
84+ } else {
85+ y_frames_prev (r, c) = dm_x (0 , r * hop_length_ + c - 1 );
86+ }
87+ }
88+ y_frames_prev (r, 0 ) = y_frames_prev (r, 1 );
89+ }
90+
91+ // Step 2: Apply pre-emphasis and scale by 32768
92+ dlib::matrix<float > y_processed (n_batch, frame_length_);
93+ for (long r = 0 ; r < n_batch; ++r) {
94+ for (long c = 0 ; c < frame_length_; ++c) {
95+ float sample = dm_x (0 , r * hop_length_ + c);
96+ float rolled_sample = y_frames_prev (r, c);
97+ y_processed (r, c) = (sample - preemphasis * rolled_sample) * 32768 .0f ;
98+ }
99+ }
100+
101+ // Step 3: Apply FFT window to each frame
102+ for (long r = 0 ; r < n_batch; ++r) {
103+ for (long c = 0 ; c < frame_length_; ++c) {
104+ y_processed (r, c) *= fft_win_[c];
105+ }
106+ }
107+
108+ // Step 4: Compute Full FFT for each frame (complex output)
109+ // add extra column for simulating STFTNorm output shape
110+ dlib::matrix<std::complex <float >> S (n_batch + 1 , n_fft_ / 2 + 1 ); // Use n_fft_ columns instead of n_rfft
111+
112+ for (long r = 0 ; r < n_batch; ++r) {
113+ dlib::matrix<float , 1 , 0 > frame = rowm (y_processed, r);
114+ dlib::matrix<float , 1 , 0 > padded_frame (1 , n_fft_);
115+ padded_frame = 0 ;
116+
117+ long copy_length = (std::min)(frame_length_, n_fft_);
118+ dlib::set_subm (padded_frame, 0 , 0 , 1 , copy_length) = dlib::subm (frame, 0 , 0 , 1 , copy_length);
119+
120+ // Convert real-valued frame to complex for full FFT
121+ dlib::matrix<std::complex <float >> padded_frame_complex (1 , n_fft_);
122+ for (long j = 0 ; j < n_fft_; ++j) {
123+ padded_frame_complex (0 , j) = std::complex <float >(padded_frame (0 , j), 0 .0f );
124+ }
125+
126+ // Compute full FFT (complex output)
127+ dlib::matrix<std::complex <float >> fft_result = dlib::fft (padded_frame_complex);
128+
129+ // Store result (n_fft_ complex values)
130+ for (long c = 0 ; c <= n_fft_ / 2 ; ++c) {
131+ S (r, c) = fft_result (0 , c);
132+ }
133+ }
134+
135+ // Compute spectral power (squared magnitude)
136+ auto S_norm = dlib::norm (S);
137+ dlib::matrix<float > spec_power = dlib::trans (S_norm);
138+
139+ std::vector<int64_t > outdim{1 , spec_power.nr (), spec_power.nc ()};
140+ auto result_size = spec_power.size ();
141+ auto out0 = stft_norm.Allocate (outdim);
142+ memcpy (out0, spec_power.steal_memory ().get (), result_size * sizeof (float ));
143+
144+ return {};
145+ }
146+
54147 static std::vector<float > hann_window (int N) {
55148 std::vector<float > window (N);
56149
@@ -137,10 +230,20 @@ class LogMel {
137230 */
138231 assert (stft_norm.Shape ().size () == 3 && stft_norm.Shape ()[0 ] == 1 );
139232 std::vector<int64_t > stft_shape = stft_norm.Shape ();
140- dlib::matrix<float > magnitudes (stft_norm.Shape ()[1 ], stft_norm.Shape ()[2 ] - 1 );
233+ int64_t n_fill_zero_col = stft_shape[1 ]; // if 8k, fill 4k - 8k hz with zeros
234+ int64_t additional_row = 0 ;
235+ if (mel_filters_.nc () > stft_shape[1 ]) {
236+ n_fill_zero_col = mel_filters_.nc () - stft_shape[1 ] - 1 ;
237+ additional_row = stft_shape[1 ] - 1 ;
238+ }
239+ dlib::matrix<float > magnitudes (stft_shape[1 ] + additional_row, stft_shape[2 ] - 1 );
141240 for (int i = 0 ; i < magnitudes.nr (); ++i) {
142- std::copy (stft_norm.Data () + i * stft_shape[2 ], stft_norm.Data () + (i + 1 ) * stft_shape[2 ] - 1 ,
143- magnitudes.begin () + i * magnitudes.nc ());
241+ if (i < n_fill_zero_col) {
242+ std::copy (stft_norm.Data () + i * stft_shape[2 ], stft_norm.Data () + (i + 1 ) * stft_shape[2 ] - 1 ,
243+ magnitudes.begin () + i * magnitudes.nc ());
244+ } else {
245+ std::fill (magnitudes.begin () + i * magnitudes.nc (), magnitudes.begin () + (i + 1 ) * magnitudes.nc (), 0 .0f );
246+ }
144247 }
145248
146249 dlib::matrix<float > mel_spec = mel_filters_ * magnitudes;
@@ -191,7 +294,8 @@ class LogMel {
191294 }
192295
193296 // Function to compute the Mel filterbank
194- static dlib::matrix<float > MelFilterBank (int n_fft, int n_mels, int sr = 16000 , float min_mel = 0 ,
297+ static dlib::matrix<float > MelFilterBank (int n_fft, int n_mels,
298+ int sr = 16000 , float min_mel = 0 ,
195299 float max_mel = 45.245640471924965 ) {
196300 // Initialize the filterbank matrix
197301 dlib::matrix<float > fbank (n_mels, n_fft / 2 + 1 );
@@ -312,7 +416,7 @@ class Phi4AudioEmbed {
312416 ortc::Tensor<float > stft_norm (&CppAllocator::Instance ());
313417 SpeechFeatures stft_normal;
314418 stft_normal.Init (sr_val == 8000 ? stft_normal_8k_attrs_: stft_normal_attrs_);
315- auto status = stft_normal.STFTNorm (pcm, stft_norm);
419+ auto status = stft_normal.SpeechLibSTFTNorm (pcm, stft_norm);
316420 if (!status.IsOk ()) {
317421 return status;
318422 }
0 commit comments