Skip to content

Commit 7d392a3

Browse files
committed
Merge branch 'dev' of github.com:m5stack/StackFlow into dev
2 parents 87744d2 + 6e503a1 commit 7d392a3

File tree

2 files changed

+214
-292
lines changed

2 files changed

+214
-292
lines changed

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 214 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "Lexicon.hpp"
1010
#include <ax_sys_api.h>
1111
#include "AudioFile.h"
12-
#include "SolaProcessor.h"
1312
#include "Lexicon.hpp"
1413

1514
#include <signal.h>
@@ -253,14 +252,16 @@ class llm_task {
253252
}
254253
return false;
255254
}
255+
256+
// Convert text to phonemes and tones
256257
std::vector<int> phones_bef, tones_bef;
257258
lexicon_->convert(msg_str, phones_bef, tones_bef);
258-
// Add blank between words
259-
auto phones = intersperse(phones_bef, 0);
260-
auto tones = intersperse(tones_bef, 0);
261-
int phone_len = phones.size();
262-
int MELOTTS_LANG_IDS = MELOTTS_LANG_IDS_MAP[mode_config_.mode];
263-
std::vector<int> langids(phone_len, MELOTTS_LANG_IDS);
259+
auto phones = intersperse(phones_bef, 0);
260+
auto tones = intersperse(tones_bef, 0);
261+
int phone_len = phones.size();
262+
std::vector<int> langids(phone_len, 3);
263+
264+
// Run the encoder to generate hidden representations
264265
auto encoder_output =
265266
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
266267
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
@@ -269,66 +270,256 @@ class llm_task {
269270
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
270271
auto zp_shape = zp_info.GetShape();
271272

272-
// Decoder parameters setup
273-
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
274-
int dec_len = zp_size / zp_shape[1];
275-
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
276-
const int pad_frames = 16;
273+
// Calculate decoder parameters
274+
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
275+
int dec_len = zp_size / zp_shape[1];
276+
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
277+
278+
const int pad_frames = 24;
277279
const int samples_per_frame = 512;
278-
const int effective_frames = dec_len - 2 * pad_frames;
280+
281+
const int effective_frames = dec_len - 2 * pad_frames;
282+
279283
int dec_slice_num =
280284
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));
281-
SolaProcessor sola(pad_frames, samples_per_frame);
285+
286+
// SOLA parameters setup
287+
const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
288+
const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
289+
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length
290+
291+
// Create fade-in/fade-out windows for smooth transitions
292+
std::vector<float> fade_in_window(sola_buffer_frame);
293+
std::vector<float> fade_out_window(sola_buffer_frame);
294+
295+
for (int i = 0; i < sola_buffer_frame; i++) {
296+
fade_in_window[i] = static_cast<float>(i) / sola_buffer_frame;
297+
fade_out_window[i] = 1.0f - fade_in_window[i];
298+
}
299+
300+
// Initialize SOLA buffer
301+
std::vector<float> sola_buffer(sola_buffer_frame, 0.0f);
302+
bool first_frame = true;
303+
282304
std::vector<float> pcmlist;
283305

306+
// Main decoding loop - process each slice
284307
for (int i = 0; i < dec_slice_num; i++) {
308+
// Calculate start position for current batch input
285309
int input_start = i * effective_frames;
310+
// For every slice after the first, start pad_frames earlier so the decoder sees leading context; clamp to 0
286311
if (i > 0) {
287312
input_start -= pad_frames;
288313
}
289-
input_start = std::max(0, input_start);
314+
input_start = std::max(0, input_start);
315+
316+
// Actual input length
290317
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));
318+
319+
// Calculate effective output range (frame level). NOTE(review): these bounds are not read again in this loop body
320+
int output_start_frame, output_end_frame;
321+
322+
if (i == 0) {
323+
// First slice: it has no leading padding, so the effective output starts at frame 0
324+
output_start_frame = 0;
325+
output_end_frame = effective_frames - 1;
326+
} else if (i == dec_slice_num - 1) {
327+
// Last frame: calculate from current segment start
328+
output_start_frame = i * effective_frames;
329+
// Last frame extends to encoder's maximum output length
330+
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
331+
} else {
332+
// Middle frames: standard calculation
333+
output_start_frame = i * effective_frames;
334+
output_end_frame = (i + 1) * effective_frames - 1;
335+
}
336+
// Prepare decoder input, initialize all to zero
291337
std::vector<float> zp(zp_size, 0);
292338

339+
// Copy data to decoder input
293340
for (int n = 0; n < zp_shape[1]; n++) {
294341
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
295342
if (copy_size > 0) {
296343
memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + input_start,
297344
sizeof(float) * copy_size);
298345
}
299346
}
347+
300348
// Run decoder
301349
std::vector<float> decoder_output(audio_slice_len);
302350
decoder_->SetInput(zp.data(), 0);
303351
decoder_->SetInput(g_matrix.data(), 1);
352+
304353
if (0 != decoder_->Run()) {
354+
SLOGI("Inference #%d: decoding failed", i + 1);
305355
throw std::string("decoder_ RunSync error");
306356
}
357+
307358
decoder_->GetOutput(decoder_output.data(), 0);
308-
std::vector<float> processed_output = sola.ProcessFrame(decoder_output, i, dec_slice_num, actual_len);
309359

310-
pcmlist.insert(pcmlist.end(), processed_output.begin(), processed_output.end());
360+
// === SOLA Processing Logic ===
361+
if (first_frame) {
362+
// Special handling for first frame - should not skip initial content
363+
// First frame starts directly from decoder output without skipping
364+
int audio_start = 0; // Start from beginning, don't skip pad_frames
365+
366+
// Calculate data length for first frame
367+
// First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
368+
// for next frame alignment
369+
int audio_len = decoder_output.size() - sola_buffer_frame;
370+
371+
// Boundary check
372+
audio_len = std::max(0, audio_len); // Ensure non-negative
373+
374+
// Add first frame data
375+
if (audio_len > 0) {
376+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start,
377+
decoder_output.begin() + audio_start + audio_len);
378+
}
379+
380+
// Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
381+
int buffer_start = audio_len;
382+
383+
// Ensure sufficient data is available for copying
384+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
385+
std::copy(decoder_output.begin() + buffer_start,
386+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
387+
} else {
388+
// Possible case: first frame data is shorter than sola_buffer_frame
389+
int available = static_cast<int>(decoder_output.size() - buffer_start);
390+
if (available > 0) {
391+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
392+
// Fill with zeros
393+
std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f);
394+
} else {
395+
// Completely insufficient data, fill all with zeros
396+
std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f);
397+
}
398+
}
399+
400+
first_frame = false;
401+
402+
} else {
403+
// Non-first frame: SOLA alignment required
404+
int audio_start = pad_frames * samples_per_frame;
405+
406+
// 1. Prepare search window - beginning portion of current frame
407+
std::vector<float> search_window(sola_buffer_frame + sola_search_frame);
408+
std::copy(decoder_output.begin() + audio_start,
409+
decoder_output.begin() + audio_start + search_window.size(), search_window.begin());
410+
411+
// 2. Find best alignment point (calculate cross-correlation)
412+
int best_offset = 0;
413+
float best_correlation = -1.0;
414+
415+
for (int offset = 0; offset <= sola_search_frame; offset++) {
416+
float correlation = 0.0;
417+
float energy = 0.0;
418+
419+
for (int j = 0; j < sola_buffer_frame; j++) {
420+
correlation += sola_buffer[j] * search_window[j + offset];
421+
energy += search_window[j + offset] * search_window[j + offset];
422+
}
423+
424+
// Normalize correlation (avoid division by zero)
425+
float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;
426+
427+
if (normalized_correlation > best_correlation) {
428+
best_correlation = normalized_correlation;
429+
best_offset = offset;
430+
}
431+
}
432+
433+
// 3. Apply alignment offset
434+
int aligned_start = audio_start + best_offset;
435+
436+
// 4. Smooth transition processing (crossfade in alignment region)
437+
std::vector<float> crossfade_region(sola_buffer_frame);
438+
439+
for (int j = 0; j < sola_buffer_frame; j++) {
440+
// Apply fade-in/fade-out window functions
441+
crossfade_region[j] =
442+
decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
443+
}
444+
445+
// 5. Add crossfade region to output
446+
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());
447+
448+
int remaining_start = aligned_start + sola_buffer_frame;
449+
450+
if (i == dec_slice_num - 1) {
451+
int total_expected_samples = audio_len * samples_per_frame / 512;
452+
453+
int processed_samples = static_cast<int>(pcmlist.size());
454+
455+
int remaining_needed = total_expected_samples - processed_samples;
456+
remaining_needed = std::max(0, remaining_needed);
457+
458+
int remaining_len =
459+
std::min(remaining_needed, static_cast<int>(decoder_output.size() - remaining_start));
460+
461+
if (remaining_len > 0) {
462+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
463+
decoder_output.begin() + remaining_start + remaining_len);
464+
}
465+
466+
} else {
467+
int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
468+
469+
remaining_len =
470+
std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));
471+
472+
if (remaining_len > 0) {
473+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
474+
decoder_output.begin() + remaining_start + remaining_len);
475+
}
476+
477+
int buffer_start = remaining_start + remaining_len;
478+
479+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
480+
std::copy(decoder_output.begin() + buffer_start,
481+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
482+
} else {
483+
int avail = static_cast<int>(decoder_output.size() - buffer_start);
484+
if (avail > 0) {
485+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(),
486+
sola_buffer.begin());
487+
}
488+
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
489+
}
490+
}
491+
}
492+
}
493+
494+
if (pcmlist.size() > audio_len) {
495+
pcmlist.resize(audio_len);
311496
}
312497

313-
double src_ratio = (mode_config_.audio_rate * 1.0f) / (mode_config_.mode_rate * 1.0f);
498+
// Post-processing: resample and convert to int16
499+
double src_ratio =
500+
static_cast<double>(mode_config_.audio_rate) / static_cast<double>(mode_config_.mode_rate);
314501
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
315502
int len;
503+
316504
resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);
317505

318506
// Convert to 16-bit PCM
319507
wav_pcm_data.reserve(len);
320508
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
321-
[](const auto val) { return (int16_t)(val * INT16_MAX); });
509+
[](const auto val) { return static_cast<int16_t>(val * INT16_MAX); });
322510

323-
// Call callback function with output
324-
if (out_callback_)
325-
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), finish);
511+
// Call the output callback function with the result
512+
if (out_callback_) {
513+
out_callback_(
514+
std::string(reinterpret_cast<char *>(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)),
515+
finish);
516+
}
326517

327518
} catch (const std::exception &e) {
328519
SLOGI("TTS processing exception: %s", e.what());
329520
return true;
330521
} catch (...) {
331-
SLOGI("TTS processing encountered unknown exception");
522+
SLOGI("TTS processing encountered an unknown exception");
332523
return true;
333524
}
334525
return false;

0 commit comments

Comments
 (0)