Skip to content

Commit 05e179f

Browse files
committed
OuteTTS: specifying speaker
1 parent 1124d0a commit 05e179f

File tree

3 files changed

+35
-21
lines changed

3 files changed

+35
-21
lines changed

models/oute.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ namespace dac
153153
}
154154
}
155155

156-
void set_additional_args(const std::map<std::string, std::string> &args)
157-
{
158-
//
159-
}
160-
161156
bool load_more(ggml::type dtype, const json::JSON &config)
162157
{
163158
const auto cfg = config["dac_config.json"];
@@ -253,6 +248,8 @@ namespace tts_llama
253248
// https://github.com/edwko/OuteTTS/blob/main/outetts/version/v3/prompt_processor.py
254249
// https://github.com/edwko/OuteTTS/blob/main/outetts/version/v3/tokens.py
255250

251+
json::JSON speaker;
252+
256253
// Special Tokens structure
257254
struct SpecialTokens
258255
{
@@ -359,24 +356,24 @@ namespace tts_llama
359356
{
360357
std::vector<std::string> codes;
361358

362-
for (const auto& word_item : words.ObjectRange())
359+
for (const auto& word_item : words.ArrayRange())
363360
{
364-
std::string word = word_item.second["word"].ToString() + special_tokens.features;
365-
word += format_string(special_tokens.time, word_item.second["duration"].ToFloat());
361+
std::string word = word_item["word"].ToString() + special_tokens.features;
362+
word += format_string(special_tokens.time, word_item["duration"].ToFloat());
366363

367364
// Add features
368-
std::vector<std::string> features = get_features(word_item.second["features"], special_tokens);
365+
std::vector<std::string> features = get_features(word_item["features"], special_tokens);
369366
for (const auto& feature : features)
370367
{
371368
word += feature;
372369
}
373370

374371
// Add pairs of c1 and c2
375372
std::vector<std::string> pairs;
376-
for (size_t idx = 0; idx < word_item.second["c1"].length(); idx++)
373+
for (size_t idx = 0; idx < word_item["c1"].length(); idx++)
377374
{
378-
std::string c1 = format_string(special_tokens.c1, (int)word_item.second["c1"][(unsigned)idx].ToInt());
379-
std::string c2 = format_string(special_tokens.c2, (int)word_item.second["c2"][(unsigned)idx].ToInt());
375+
std::string c1 = format_string(special_tokens.c1, (int)word_item["c1"][(unsigned)idx].ToInt());
376+
std::string c2 = format_string(special_tokens.c2, (int)word_item["c2"][(unsigned)idx].ToInt());
380377
pairs.push_back(c1 + c2);
381378
}
382379

@@ -533,15 +530,19 @@ namespace tts_llama
533530
return prompt;
534531
}
535532

536-
static json::JSON speaker_from_file(const std::string & speaker_file)
533+
std::string get_completion_prompt(const std::string& text)
537534
{
538-
std::ifstream file(speaker_file);
539-
CHATLLM_CHECK(file) << "Failed to open file: " << speaker_file;
540-
541-
std::stringstream buffer;
542-
buffer << file.rdbuf();
535+
json::JSON clone(speaker);
536+
return get_completion_prompt(text, clone);
537+
}
543538

544-
return json::JSON::Load(buffer.str());
539+
static void load_speaker_from_args(const std::map<std::string, std::string> &args)
540+
{
541+
auto x = args.find("speaker");
542+
if (x != args.end())
543+
{
544+
speaker = json::JSON::Load(utils::load_file(x->second.c_str()));
545+
}
545546
}
546547

547548
std::vector<std::vector<int>> extract_codebooks(const std::string& codes) {
@@ -647,7 +648,7 @@ namespace tts_llama
647648

648649
void set_additional_args(const std::map<std::string, std::string> &args) override
649650
{
650-
codec.set_additional_args(args);
651+
load_speaker_from_args(args);
651652
}
652653

653654
public:
@@ -733,7 +734,7 @@ namespace tts_qwen3
733734

734735
void set_additional_args(const std::map<std::string, std::string> &args) override
735736
{
736-
codec.set_additional_args(args);
737+
tts_llama::load_speaker_from_args(args);
737738
}
738739

739740
public:

src/basics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,6 @@ namespace utils
7070

7171
// for (i = start; i < end; i++) { func(i); }
7272
void parallel_for(int64_t start, int64_t end, std::function<void(int64_t)> func, int num_threads = 0);
73+
74+
std::string load_file(const char *fn);
7375
}

src/vectorstore.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,4 +457,15 @@ namespace utils
457457
}
458458
}
459459
}
460+
461+
std::string load_file(const char *fn)
462+
{
463+
std::ifstream file(fn);
464+
if (!file.is_open()) return "";
465+
466+
std::stringstream buffer;
467+
buffer << file.rdbuf();
468+
469+
return buffer.str();
470+
}
460471
}

0 commit comments

Comments
 (0)