9
9
#include " Lexicon.hpp"
10
10
#include < ax_sys_api.h>
11
11
#include " AudioFile.h"
12
- #include " SolaProcessor.h"
13
12
#include " Lexicon.hpp"
14
13
15
14
#include < signal.h>
@@ -253,14 +252,16 @@ class llm_task {
253
252
}
254
253
return false ;
255
254
}
255
+
256
+ // Convert text to phonemes and tones
256
257
std::vector<int > phones_bef, tones_bef;
257
258
lexicon_->convert (msg_str, phones_bef, tones_bef);
258
- // Add blank between words
259
- auto phones = intersperse (phones_bef , 0 );
260
- auto tones = intersperse (tones_bef, 0 );
261
- int phone_len = phones. size ( );
262
- int MELOTTS_LANG_IDS = MELOTTS_LANG_IDS_MAP[mode_config_. mode ];
263
- std::vector< int > langids (phone_len, MELOTTS_LANG_IDS);
259
+ auto phones = intersperse (phones_bef, 0 );
260
+ auto tones = intersperse (tones_bef , 0 );
261
+ int phone_len = phones. size ( );
262
+ std::vector< int > langids ( phone_len, 3 );
263
+
264
+ // Run the encoder to generate hidden representations
264
265
auto encoder_output =
265
266
encoder_->Run (phones, tones, langids, g_matrix, mode_config_.noise_scale , mode_config_.noise_scale_w ,
266
267
mode_config_.get_length_scale (), mode_config_.sdp_ratio );
@@ -269,66 +270,256 @@ class llm_task {
269
270
auto zp_info = encoder_output.at (0 ).GetTensorTypeAndShapeInfo ();
270
271
auto zp_shape = zp_info.GetShape ();
271
272
272
- // Decoder parameters setup
273
- int zp_size = decoder_->GetInputSize (0 ) / sizeof (float );
274
- int dec_len = zp_size / zp_shape[1 ];
275
- int audio_slice_len = decoder_->GetOutputSize (0 ) / sizeof (float );
276
- const int pad_frames = 16 ;
273
+ // Calculate decoder parameters
274
+ int zp_size = decoder_->GetInputSize (0 ) / sizeof (float );
275
+ int dec_len = zp_size / zp_shape[1 ];
276
+ int audio_slice_len = decoder_->GetOutputSize (0 ) / sizeof (float );
277
+
278
+ const int pad_frames = 24 ;
277
279
const int samples_per_frame = 512 ;
278
- const int effective_frames = dec_len - 2 * pad_frames;
280
+
281
+ const int effective_frames = dec_len - 2 * pad_frames;
282
+
279
283
int dec_slice_num =
280
284
static_cast <int >(std::ceil (static_cast <double >(zp_shape[2 ]) / static_cast <double >(effective_frames)));
281
- SolaProcessor sola (pad_frames, samples_per_frame);
285
+
286
+ // SOLA parameters setup
287
+ const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
288
+ const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
289
+ const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length
290
+
291
+ // Create fade-in/fade-out windows for smooth transitions
292
+ std::vector<float > fade_in_window (sola_buffer_frame);
293
+ std::vector<float > fade_out_window (sola_buffer_frame);
294
+
295
+ for (int i = 0 ; i < sola_buffer_frame; i++) {
296
+ fade_in_window[i] = static_cast <float >(i) / sola_buffer_frame;
297
+ fade_out_window[i] = 1 .0f - fade_in_window[i];
298
+ }
299
+
300
+ // Initialize SOLA buffer
301
+ std::vector<float > sola_buffer (sola_buffer_frame, 0 .0f );
302
+ bool first_frame = true ;
303
+
282
304
std::vector<float > pcmlist;
283
305
306
+ // Main decoding loop - process each slice
284
307
for (int i = 0 ; i < dec_slice_num; i++) {
308
+ // Calculate start position for current batch input
285
309
int input_start = i * effective_frames;
310
+ // Consider forward padding, but ensure non-negative
286
311
if (i > 0 ) {
287
312
input_start -= pad_frames;
288
313
}
289
- input_start = std::max (0 , input_start);
314
+ input_start = std::max (0 , input_start);
315
+
316
+ // Actual input length
290
317
int actual_len = std::min (dec_len, static_cast <int >(zp_shape[2 ] - input_start));
318
+
319
+ // Calculate effective output range (frame level)
320
+ int output_start_frame, output_end_frame;
321
+
322
+ if (i == 0 ) {
323
+ // First frame: skip padding at beginning
324
+ output_start_frame = 0 ;
325
+ output_end_frame = effective_frames - 1 ;
326
+ } else if (i == dec_slice_num - 1 ) {
327
+ // Last frame: calculate from current segment start
328
+ output_start_frame = i * effective_frames;
329
+ // Last frame extends to encoder's maximum output length
330
+ output_end_frame = static_cast <int >(zp_shape[2 ]) - 1 ;
331
+ } else {
332
+ // Middle frames: standard calculation
333
+ output_start_frame = i * effective_frames;
334
+ output_end_frame = (i + 1 ) * effective_frames - 1 ;
335
+ }
336
+ // Prepare decoder input, initialize all to zero
291
337
std::vector<float > zp (zp_size, 0 );
292
338
339
+ // Copy data to decoder input
293
340
for (int n = 0 ; n < zp_shape[1 ]; n++) {
294
341
int copy_size = std::min (actual_len, static_cast <int >(zp_shape[2 ] - input_start));
295
342
if (copy_size > 0 ) {
296
343
memcpy (zp.data () + n * dec_len, zp_data + n * zp_shape[2 ] + input_start,
297
344
sizeof (float ) * copy_size);
298
345
}
299
346
}
347
+
300
348
// Run decoder
301
349
std::vector<float > decoder_output (audio_slice_len);
302
350
decoder_->SetInput (zp.data (), 0 );
303
351
decoder_->SetInput (g_matrix.data (), 1 );
352
+
304
353
if (0 != decoder_->Run ()) {
354
+ SLOGI (" Inference #%d: decoding failed" , i + 1 );
305
355
throw std::string (" decoder_ RunSync error" );
306
356
}
357
+
307
358
decoder_->GetOutput (decoder_output.data (), 0 );
308
- std::vector<float > processed_output = sola.ProcessFrame (decoder_output, i, dec_slice_num, actual_len);
309
359
310
- pcmlist.insert (pcmlist.end (), processed_output.begin (), processed_output.end ());
360
+ // === SOLA Processing Logic ===
361
+ if (first_frame) {
362
+ // Special handling for first frame - should not skip initial content
363
+ // First frame starts directly from decoder output without skipping
364
+ int audio_start = 0 ; // Start from beginning, don't skip pad_frames
365
+
366
+ // Calculate data length for first frame
367
+ // First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
368
+ // for next frame alignment
369
+ int audio_len = decoder_output.size () - sola_buffer_frame;
370
+
371
+ // Boundary check
372
+ audio_len = std::max (0 , audio_len); // Ensure non-negative
373
+
374
+ // Add first frame data
375
+ if (audio_len > 0 ) {
376
+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + audio_start,
377
+ decoder_output.begin () + audio_start + audio_len);
378
+ }
379
+
380
+ // Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
381
+ int buffer_start = audio_len;
382
+
383
+ // Ensure sufficient data is available for copying
384
+ if (buffer_start + sola_buffer_frame <= decoder_output.size ()) {
385
+ std::copy (decoder_output.begin () + buffer_start,
386
+ decoder_output.begin () + buffer_start + sola_buffer_frame, sola_buffer.begin ());
387
+ } else {
388
+ // Possible case: first frame data is shorter than sola_buffer_frame
389
+ int available = static_cast <int >(decoder_output.size () - buffer_start);
390
+ if (available > 0 ) {
391
+ std::copy (decoder_output.begin () + buffer_start, decoder_output.end (), sola_buffer.begin ());
392
+ // Fill with zeros
393
+ std::fill (sola_buffer.begin () + available, sola_buffer.end (), 0 .0f );
394
+ } else {
395
+ // Completely insufficient data, fill all with zeros
396
+ std::fill (sola_buffer.begin (), sola_buffer.end (), 0 .0f );
397
+ }
398
+ }
399
+
400
+ first_frame = false ;
401
+
402
+ } else {
403
+ // Non-first frame: SOLA alignment required
404
+ int audio_start = pad_frames * samples_per_frame;
405
+
406
+ // 1. Prepare search window - beginning portion of current frame
407
+ std::vector<float > search_window (sola_buffer_frame + sola_search_frame);
408
+ std::copy (decoder_output.begin () + audio_start,
409
+ decoder_output.begin () + audio_start + search_window.size (), search_window.begin ());
410
+
411
+ // 2. Find best alignment point (calculate cross-correlation)
412
+ int best_offset = 0 ;
413
+ float best_correlation = -1.0 ;
414
+
415
+ for (int offset = 0 ; offset <= sola_search_frame; offset++) {
416
+ float correlation = 0.0 ;
417
+ float energy = 0.0 ;
418
+
419
+ for (int j = 0 ; j < sola_buffer_frame; j++) {
420
+ correlation += sola_buffer[j] * search_window[j + offset];
421
+ energy += search_window[j + offset] * search_window[j + offset];
422
+ }
423
+
424
+ // Normalize correlation (avoid division by zero)
425
+ float normalized_correlation = (energy > 1e-8 ) ? correlation / std::sqrt (energy) : 0 .0f ;
426
+
427
+ if (normalized_correlation > best_correlation) {
428
+ best_correlation = normalized_correlation;
429
+ best_offset = offset;
430
+ }
431
+ }
432
+
433
+ // 3. Apply alignment offset
434
+ int aligned_start = audio_start + best_offset;
435
+
436
+ // 4. Smooth transition processing (crossfade in alignment region)
437
+ std::vector<float > crossfade_region (sola_buffer_frame);
438
+
439
+ for (int j = 0 ; j < sola_buffer_frame; j++) {
440
+ // Apply fade-in/fade-out window functions
441
+ crossfade_region[j] =
442
+ decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
443
+ }
444
+
445
+ // 5. Add crossfade region to output
446
+ pcmlist.insert (pcmlist.end (), crossfade_region.begin (), crossfade_region.end ());
447
+
448
+ int remaining_start = aligned_start + sola_buffer_frame;
449
+
450
+ if (i == dec_slice_num - 1 ) {
451
+ int total_expected_samples = audio_len * samples_per_frame / 512 ;
452
+
453
+ int processed_samples = static_cast <int >(pcmlist.size ());
454
+
455
+ int remaining_needed = total_expected_samples - processed_samples;
456
+ remaining_needed = std::max (0 , remaining_needed);
457
+
458
+ int remaining_len =
459
+ std::min (remaining_needed, static_cast <int >(decoder_output.size () - remaining_start));
460
+
461
+ if (remaining_len > 0 ) {
462
+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + remaining_start,
463
+ decoder_output.begin () + remaining_start + remaining_len);
464
+ }
465
+
466
+ } else {
467
+ int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
468
+
469
+ remaining_len =
470
+ std::min (remaining_len, static_cast <int >(decoder_output.size () - remaining_start));
471
+
472
+ if (remaining_len > 0 ) {
473
+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + remaining_start,
474
+ decoder_output.begin () + remaining_start + remaining_len);
475
+ }
476
+
477
+ int buffer_start = remaining_start + remaining_len;
478
+
479
+ if (buffer_start + sola_buffer_frame <= decoder_output.size ()) {
480
+ std::copy (decoder_output.begin () + buffer_start,
481
+ decoder_output.begin () + buffer_start + sola_buffer_frame, sola_buffer.begin ());
482
+ } else {
483
+ int avail = static_cast <int >(decoder_output.size () - buffer_start);
484
+ if (avail > 0 ) {
485
+ std::copy (decoder_output.begin () + buffer_start, decoder_output.end (),
486
+ sola_buffer.begin ());
487
+ }
488
+ std::fill (sola_buffer.begin () + avail, sola_buffer.end (), 0 .0f );
489
+ }
490
+ }
491
+ }
492
+ }
493
+
494
+ if (pcmlist.size () > audio_len) {
495
+ pcmlist.resize (audio_len);
311
496
}
312
497
313
- double src_ratio = (mode_config_.audio_rate * 1 .0f ) / (mode_config_.mode_rate * 1 .0f );
498
+ // Post-processing: resample and convert to int16
499
+ double src_ratio =
500
+ static_cast <double >(mode_config_.audio_rate ) / static_cast <double >(mode_config_.mode_rate );
314
501
std::vector<float > tmp_pcm ((pcmlist.size () * src_ratio + 1 ));
315
502
int len;
503
+
316
504
resample_audio (pcmlist.data (), pcmlist.size (), tmp_pcm.data (), &len, src_ratio);
317
505
318
506
// Convert to 16-bit PCM
319
507
wav_pcm_data.reserve (len);
320
508
std::transform (tmp_pcm.begin (), tmp_pcm.begin () + len, std::back_inserter (wav_pcm_data),
321
- [](const auto val) { return ( int16_t ) (val * INT16_MAX); });
509
+ [](const auto val) { return static_cast < int16_t > (val * INT16_MAX); });
322
510
323
- // Call callback function with output
324
- if (out_callback_)
325
- out_callback_ (std::string ((char *)wav_pcm_data.data (), wav_pcm_data.size () * sizeof (int16_t )), finish);
511
+ // Call the output callback function with the result
512
+ if (out_callback_) {
513
+ out_callback_ (
514
+ std::string (reinterpret_cast <char *>(wav_pcm_data.data ()), wav_pcm_data.size () * sizeof (int16_t )),
515
+ finish);
516
+ }
326
517
327
518
} catch (const std::exception &e) {
328
519
SLOGI (" TTS processing exception: %s" , e.what ());
329
520
return true ;
330
521
} catch (...) {
331
- SLOGI (" TTS processing encountered unknown exception" );
522
+ SLOGI (" TTS processing encountered an unknown exception" );
332
523
return true ;
333
524
}
334
525
return false ;
0 commit comments