Skip to content

Commit e5944f2

Browse files
authored
Merge pull request #14 from yuyun2000/opt/melotts
Fix SOLA algorithm implementation
2 parents 3b20852 + 835daf1 commit e5944f2

File tree

1 file changed

+49
-25
lines changed
  • projects/llm_framework/main_melotts/src

1 file changed

+49
-25
lines changed

projects/llm_framework/main_melotts/src/main.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -469,42 +469,66 @@ class llm_task {
469469
// 5. Add crossfade region to output
470470
pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end());
471471

472-
// 6. Add remaining valid audio data
473472
int remaining_start = aligned_start + sola_buffer_frame;
474-
int remaining_len = (i == dec_slice_num - 1)
475-
? (actual_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame
476-
: (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
477473

478-
// Boundary check
479-
remaining_len = std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));
474+
if (i == dec_slice_num - 1) {
475+
int total_expected_samples = audio_len * samples_per_frame / 512;
480476

481-
if (remaining_len > 0) {
482-
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
483-
decoder_output.begin() + remaining_start + remaining_len);
484-
}
477+
int processed_samples = static_cast<int>(pcmlist.size());
485478

486-
// 7. Update SOLA buffer for next frame
487-
int buffer_start = remaining_start + remaining_len;
479+
int remaining_needed = total_expected_samples - processed_samples;
480+
remaining_needed = std::max(0, remaining_needed);
481+
482+
int remaining_len =
483+
std::min(remaining_needed, static_cast<int>(decoder_output.size() - remaining_start));
484+
485+
SLOGI("Inference #%d (final): Expected total=%d, processed=%d, needed=%d, available=%d", i + 1,
486+
total_expected_samples, processed_samples, remaining_needed, remaining_len);
487+
488+
if (remaining_len > 0) {
489+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
490+
decoder_output.begin() + remaining_start + remaining_len);
491+
}
488492

489-
// Check if there's enough data for the next buffer
490-
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
491-
std::copy(decoder_output.begin() + buffer_start,
492-
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
493493
} else {
494-
// If insufficient, fill with zeros
495-
int avail = static_cast<int>(decoder_output.size() - buffer_start);
496-
if (avail > 0) {
497-
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin());
494+
int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
495+
496+
remaining_len =
497+
std::min(remaining_len, static_cast<int>(decoder_output.size() - remaining_start));
498+
499+
if (remaining_len > 0) {
500+
pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start,
501+
decoder_output.begin() + remaining_start + remaining_len);
502+
}
503+
504+
int buffer_start = remaining_start + remaining_len;
505+
506+
if (buffer_start + sola_buffer_frame <= decoder_output.size()) {
507+
std::copy(decoder_output.begin() + buffer_start,
508+
decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin());
509+
} else {
510+
int avail = static_cast<int>(decoder_output.size() - buffer_start);
511+
if (avail > 0) {
512+
std::copy(decoder_output.begin() + buffer_start, decoder_output.end(),
513+
sola_buffer.begin());
514+
}
515+
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
498516
}
499-
std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f);
500-
}
501517

502-
SLOGI("Inference #%d: Added %d + %d samples to output, cumulative length: %zu", i + 1,
503-
sola_buffer_frame, remaining_len, pcmlist.size());
518+
SLOGI("Inference #%d: Added %d + %d samples to output, cumulative length: %zu", i + 1,
519+
sola_buffer_frame, remaining_len, pcmlist.size());
520+
}
504521
}
505522
}
506523

507-
SLOGI("All inference completed, generated PCM length: %zu", pcmlist.size());
524+
SLOGI("All inference completed, raw generated PCM length: %zu", pcmlist.size());
525+
526+
if (pcmlist.size() > audio_len) {
527+
SLOGI("Truncating output from %zu to %d samples as per encoder prediction", pcmlist.size(), audio_len);
528+
pcmlist.resize(audio_len);
529+
}
530+
531+
SLOGI("Final PCM length after truncation: %zu", pcmlist.size());
508532

509533
// Post-processing: resample and convert to int16
510534
double src_ratio =

0 commit comments

Comments
 (0)