Skip to content

Commit 3b20852

Browse files
authored
Merge pull request #13 from yuyun2000/opt/melotts
Opt/melotts
2 parents 74603be + 9e7342f commit 3b20852

File tree

3 files changed

+233
-292
lines changed

3 files changed

+233
-292
lines changed

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 232 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
#include "Lexicon.hpp"
1010
#include <ax_sys_api.h>
1111
#include "AudioFile.h"
12-
#include "SolaProcessor.h"
1312
#include "Lexicon.hpp"
1413

1514
#include <signal.h>
@@ -253,14 +252,19 @@ class llm_task {
253252
}
254253
return false;
255254
}
255+
SLOGI("Processing text: %s", msg_str.c_str());
256+
257+
// Convert text to phonemes and tones
256258
std::vector<int> phones_bef, tones_bef;
257259
lexicon_->convert(msg_str, phones_bef, tones_bef);
258-
// Add blank between words
259-
auto phones = intersperse(phones_bef, 0);
260-
auto tones = intersperse(tones_bef, 0);
261-
int phone_len = phones.size();
262-
int MELOTTS_LANG_IDS = MELOTTS_LANG_IDS_MAP[mode_config_.mode];
263-
std::vector<int> langids(phone_len, MELOTTS_LANG_IDS);
260+
auto phones = intersperse(phones_bef, 0);
261+
auto tones = intersperse(tones_bef, 0);
262+
int phone_len = phones.size();
263+
std::vector<int> langids(phone_len, 3);
264+
265+
SLOGI("Phoneme conversion completed, length: %d", phone_len);
266+
267+
// Run the encoder to generate hidden representations
264268
auto encoder_output =
265269
encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w,
266270
mode_config_.get_length_scale(), mode_config_.sdp_ratio);
@@ -269,66 +273,272 @@ class llm_task {
269273
auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo();
270274
auto zp_shape = zp_info.GetShape();
271275

272-
// Decoder parameters setup
273-
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
274-
int dec_len = zp_size / zp_shape[1];
275-
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
276+
SLOGI("Encoder output completed, shape: [%ld, %ld, %ld], expected audio length: %d", zp_shape[0],
277+
zp_shape[1], zp_shape[2], audio_len);
278+
279+
// Calculate decoder parameters
280+
int zp_size = decoder_->GetInputSize(0) / sizeof(float);
281+
int dec_len = zp_size / zp_shape[1];
282+
int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float);
283+
276284
const int pad_frames = 16;
277285
const int samples_per_frame = 512;
278-
const int effective_frames = dec_len - 2 * pad_frames;
286+
287+
SLOGI("Decoder configuration: frame length=%d, audio slice length=%d, pad length=%d, samples per frame=%d",
288+
dec_len, audio_slice_len, pad_frames, samples_per_frame);
289+
290+
const int effective_frames = dec_len - 2 * pad_frames;
291+
279292
int dec_slice_num =
280293
static_cast<int>(std::ceil(static_cast<double>(zp_shape[2]) / static_cast<double>(effective_frames)));
281-
SolaProcessor sola(pad_frames, samples_per_frame);
294+
295+
SLOGI("Will perform %d inferences, each with effective frames: %d", dec_slice_num, effective_frames);
296+
297+
// SOLA parameters setup
298+
const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
299+
const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
300+
const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length
301+
302+
// Create fade-in/fade-out windows for smooth transitions
303+
std::vector<float> fade_in_window(sola_buffer_frame);
304+
std::vector<float> fade_out_window(sola_buffer_frame);
305+
306+
for (int i = 0; i < sola_buffer_frame; i++) {
307+
fade_in_window[i] = static_cast<float>(i) / sola_buffer_frame;
308+
fade_out_window[i] = 1.0f - fade_in_window[i];
309+
}
310+
311+
// Initialize SOLA buffer
312+
std::vector<float> sola_buffer(sola_buffer_frame, 0.0f);
313+
bool first_frame = true;
314+
282315
std::vector<float> pcmlist;
283316

317+
// Main decoding loop - process each slice
284318
for (int i = 0; i < dec_slice_num; i++) {
319+
// Calculate start position for current batch input
285320
int input_start = i * effective_frames;
321+
// Consider forward padding, but ensure non-negative
286322
if (i > 0) {
287323
input_start -= pad_frames;
288324
}
289-
input_start = std::max(0, input_start);
325+
input_start = std::max(0, input_start);
326+
327+
// Actual input length
290328
int actual_len = std::min(dec_len, static_cast<int>(zp_shape[2] - input_start));
329+
330+
// Calculate effective output range (frame level)
331+
int output_start_frame, output_end_frame;
332+
333+
if (i == 0) {
334+
// First slice: no leading padding is prepended to the input, so output starts at frame 0
335+
output_start_frame = 0;
336+
output_end_frame = effective_frames - 1;
337+
} else if (i == dec_slice_num - 1) {
338+
// Last frame: calculate from current segment start
339+
output_start_frame = i * effective_frames;
340+
// Last frame extends to encoder's maximum output length
341+
output_end_frame = static_cast<int>(zp_shape[2]) - 1;
342+
} else {
343+
// Middle frames: standard calculation
344+
output_start_frame = i * effective_frames;
345+
output_end_frame = (i + 1) * effective_frames - 1;
346+
}
347+
348+
SLOGI("Inference #%d: input frame range=[%d-%d], actual length=%d, output frame range=[%d-%d]", i + 1,
349+
input_start, input_start + actual_len - 1, actual_len, output_start_frame, output_end_frame);
350+
351+
// Prepare decoder input, initialize all to zero
291352
std::vector<float> zp(zp_size, 0);
292353

354+
// Copy data to decoder input
293355
for (int n = 0; n < zp_shape[1]; n++) {
294356
int copy_size = std::min(actual_len, static_cast<int>(zp_shape[2] - input_start));
295357
if (copy_size > 0) {
296358
memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + input_start,
297359
sizeof(float) * copy_size);
298360
}
299361
}
362+
300363
// Run decoder
301364
std::vector<float> decoder_output(audio_slice_len);
302365
decoder_->SetInput(zp.data(), 0);
303366
decoder_->SetInput(g_matrix.data(), 1);
367+
368+
SLOGI("Inference #%d: starting decoding...", i + 1);
369+
304370
if (0 != decoder_->Run()) {
371+
SLOGI("Inference #%d: decoding failed", i + 1);
305372
throw std::string("decoder_ RunSync error");
306373
}
374+
307375
decoder_->GetOutput(decoder_output.data(), 0);
308-
std::vector<float> processed_output = sola.ProcessFrame(decoder_output, i, dec_slice_num, actual_len);
309376

310-
pcmlist.insert(pcmlist.end(), processed_output.begin(), processed_output.end());
377+
// === SOLA Processing Logic ===
378+
if (first_frame) {
379+
// Special handling for first frame - should not skip initial content
380+
// First frame starts directly from decoder output without skipping
381+
int audio_start = 0; // Start from beginning, don't skip pad_frames
382+
383+
// Calculate data length for first frame
384+
// First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
385+
// for next frame alignment
386+
int audio_len = decoder_output.size() - sola_buffer_frame;
387+
388+
// Boundary check
389+
audio_len = std::max(0, audio_len); // Ensure non-negative
390+
391+
// Add first frame data
392+
if (audio_len > 0) {
393+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start,
394+
decoder_output.begin() + audio_start + audio_len);
395+
}
396+
397+
// Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
398+
int buffer_start = audio_len;
399+
400+
// Ensure sufficient data is available for copying
401+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
402+
std::copy(decoder_output.begin() + buffer_start,
403+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
404+
} else {
405+
// Possible case: first frame data is shorter than sola_buffer_frame
406+
int available = static_cast<int>(decoder_output.size() - buffer_start);
407+
if (available > 0) {
408+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
409+
// Fill with zeros
410+
std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f);
411+
} else {
412+
// Completely insufficient data, fill all with zeros
413+
std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f);
414+
}
415+
}
416+
417+
first_frame = false;
418+
419+
SLOGI(
420+
"Inference #%d: First frame processing, added %d samples from position %d to output, saved %d "
421+
"samples to SOLA buffer",
422+
i + 1, audio_len, audio_start, sola_buffer_frame);
423+
} else {
424+
// Non-first frame: SOLA alignment required
425+
int audio_start = pad_frames * samples_per_frame;
426+
427+
// 1. Prepare search window - beginning portion of current frame
428+
std::vector<float> search_window(sola_buffer_frame + sola_search_frame);
429+
std::copy(decoder_output.begin() + audio_start,
430+
decoder_output.begin() + audio_start + search_window.size(), search_window.begin());
431+
432+
// 2. Find best alignment point (calculate cross-correlation)
433+
int best_offset = 0;
434+
float best_correlation = -1.0;
435+
436+
for (int offset = 0; offset <= sola_search_frame; offset++) {
437+
float correlation = 0.0;
438+
float energy = 0.0;
439+
440+
for (int j = 0; j < sola_buffer_frame; j++) {
441+
correlation += sola_buffer[j] * search_window[j + offset];
442+
energy += search_window[j + offset] * search_window[j + offset];
443+
}
444+
445+
// Normalize correlation (avoid division by zero)
446+
float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f;
447+
448+
if (normalized_correlation > best_correlation) {
449+
best_correlation = normalized_correlation;
450+
best_offset = offset;
451+
}
452+
}
453+
454+
SLOGI("Inference #%d: SOLA found best alignment offset %d with correlation coefficient %f", i + 1,
455+
best_offset, best_correlation);
456+
457+
// 3. Apply alignment offset
458+
int aligned_start = audio_start + best_offset;
459+
460+
// 4. Smooth transition processing (crossfade in alignment region)
461+
std::vector<float> crossfade_region(sola_buffer_frame);
462+
463+
for (int j = 0; j < sola_buffer_frame; j++) {
464+
// Apply fade-in/fade-out window functions
465+
crossfade_region[j] =
466+
decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
467+
}
468+
469+
// 5. Add crossfade region to output
470+
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());
471+
472+
// 6. Add remaining valid audio data
473+
int remaining_start = aligned_start + sola_buffer_frame;
474+
int remaining_len = (i == dec_slice_num - 1)
475+
? (actual_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame
476+
: (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
477+
478+
// Boundary check
479+
remaining_len = std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));
480+
481+
if (remaining_len > 0) {
482+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
483+
decoder_output.begin() + remaining_start + remaining_len);
484+
}
485+
486+
// 7. Update SOLA buffer for next frame
487+
int buffer_start = remaining_start + remaining_len;
488+
489+
// Check if there's enough data for the next buffer
490+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
491+
std::copy(decoder_output.begin() + buffer_start,
492+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
493+
} else {
494+
// If insufficient, copy whatever samples are available and zero-fill the rest
495+
int avail = static_cast<int>(decoder_output.size() - buffer_start);
496+
if (avail > 0) {
497+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
498+
}
499+
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
500+
}
501+
502+
SLOGI("Inference #%d: Added %d + %d samples to output, cumulative length: %zu", i + 1,
503+
sola_buffer_frame, remaining_len, pcmlist.size());
504+
}
311505
}
312506

313-
double src_ratio = (mode_config_.audio_rate * 1.0f) / (mode_config_.mode_rate * 1.0f);
507+
SLOGI("All inference completed, generated PCM length: %zu", pcmlist.size());
508+
509+
// Post-processing: resample and convert to int16
510+
double src_ratio =
511+
static_cast<double>(mode_config_.audio_rate) / static_cast<double>(mode_config_.mode_rate);
314512
std::vector<float> tmp_pcm((pcmlist.size() * src_ratio + 1));
315513
int len;
514+
515+
SLOGI("Starting audio resampling, source rate: %f, target rate: %f, ratio: %f",
516+
static_cast<float>(mode_config_.mode_rate), static_cast<float>(mode_config_.audio_rate), src_ratio);
517+
316518
resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio);
317519

520+
SLOGI("Resampling completed, length after resampling: %d", len);
521+
318522
// Convert to 16-bit PCM
319523
wav_pcm_data.reserve(len);
320524
std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data),
321-
[](const auto val) { return (int16_t)(val * INT16_MAX); });
525+
[](const auto val) { return static_cast<int16_t>(val * INT16_MAX); });
526+
527+
SLOGI("Final audio length: %zu samples", wav_pcm_data.size());
322528

323-
// Call callback function with output
324-
if (out_callback_)
325-
out_callback_(std::string((char *)wav_pcm_data.data(), wav_pcm_data.size() * sizeof(int16_t)), finish);
529+
// Call the output callback function with the result
530+
if (out_callback_) {
531+
out_callback_(
532+
std::string(reinterpret_cast<char *>(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)),
533+
finish);
534+
}
326535

536+
SLOGI("TTS processing completed, output callback invoked");
327537
} catch (const std::exception &e) {
328538
SLOGI("TTS processing exception: %s", e.what());
329539
return true;
330540
} catch (...) {
331-
SLOGI("TTS processing encountered unknown exception");
541+
SLOGI("TTS processing encountered an unknown exception");
332542
return true;
333543
}
334544
return false;

projects/llm_framework/main_melotts/src/runner/Lexicon.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
#include <iostream>
1010
#include "../../../../../SDK/components/utilities/include/sample_log.h"
1111
// Debug logging switch - set to true to enable debug logs
12-
static bool DEBUG_LOGGING = false;
12+
static bool DEBUG_LOGGING = true;
1313
// Macro for debug logging
1414
#define DEBUG_LOG(fmt, ...) \
1515
do { \

0 commit comments

Comments
 (0)