Skip to content

Commit 79a5685

Browse files
committed
fix speaker; optimize decoder for special tokens.
1 parent 05e179f commit 79a5685

File tree

5 files changed

+140
-24
lines changed

5 files changed

+140
-24
lines changed

docs/models.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,16 @@ Please use `--format completion` for these models.
280280
[SNAC-24kHz](https://huggingface.co/mlx-community/snac_24khz/tree/556af1cd3b1c5f2d294f6aa9bb886245d7b716ac) is used as codec.
281281
Use these additional command line options when converting: `--name Orpheus-TTS -a Orpheus-TTS --snac_model /path/to/snac_24kHz`
282282

283+
Use `--set voice XX` to select voice `XX`, such as `tara`. [More info](https://github.com/canopyai/Orpheus-TTS?tab=readme-ov-file#prompting).
284+
283285
* OuteTTS:
284286
* [x] 1.0: [1B](https://huggingface.co/OuteAI/Llama-OuteTTS-1.0-1B/commit/911e296ce01148a01f3af9329163b0d298ac33a1), [0.6B](https://huggingface.co/OuteAI/OuteTTS-1.0-0.6B/tree/e7bcd87b0ca47fd8c46317c8f745a5e4e19c7b5c)
285287

286288
[DAC.speech.v1.0 1.5kbps](https://huggingface.co/ibm-research/DAC.speech.v1.0/commits/main) is used as codec.
287289
Use these additional command line options when converting: `--name OuteTTS -a OuteTTS --dac_model /path/to/dac`
288290

291+
Use `--set speaker /path/to/speaker.json` to select a speaker profile. [More info](https://github.com/edwko/OuteTTS/blob/main/docs/interface_usage.md#creating-custom-speaker-profiles).
292+
289293
## Multimodal Models
290294

291295
* Fuyu (`FuyuForCausalLM`)

models/oute.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ namespace tts_llama
500500
std::string normalized_text = text_normalization(text);
501501

502502
std::string prompt;
503-
if (!speaker.IsNull()) {
503+
if (!speaker.IsNull())
504+
{
504505
// Merge speaker text
505506
auto [merged_text, separator] = merge_speaker_text(normalized_text, speaker["text"].ToString());
506507
normalized_text = merged_text;
@@ -580,7 +581,7 @@ namespace tts_llama
580581

581582
void Tokenizer::encode(const std::string &text, std::vector<int> &ids) const
582583
{
583-
auto prompt = get_completion_prompt(text, json::JSON::_null);
584+
auto prompt = get_completion_prompt(text);
584585
BaseTokenizer::encode(prompt, ids);
585586
}
586587

@@ -689,7 +690,7 @@ namespace tts_qwen3
689690

690691
void Tokenizer::encode(const std::string &text, std::vector<int> &ids) const
691692
{
692-
auto prompt = tts_llama::get_completion_prompt(text, json::JSON::_null);
693+
auto prompt = tts_llama::get_completion_prompt(text);
693694
BaseTokenizer::encode(prompt, ids);
694695
}
695696

src/basics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,6 @@ namespace utils
7272
void parallel_for(int64_t start, int64_t end, std::function<void(int64_t)> func, int num_threads = 0);
7373

7474
std::string load_file(const char *fn);
75+
76+
//#define TIME_STAMP (std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count())
7577
}

src/tokenizer.cpp

Lines changed: 101 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cstring>
66
#include <limits>
77
#include <regex>
8+
#include <iostream>
89

910
#include "unicode.h"
1011
#include "chat.h"
@@ -474,6 +475,7 @@ size_t BPEProcessor2::Load(DataReader *data_reader, int n_vocab)
474475
vocab_.id_to_token.resize(piece_size);
475476
load_vocab_merges(vocab_, reader);
476477
build_special_token_cache(vocab_);
478+
searcher.rebuild(vocab_.special_tokens_cache);
477479

478480
return reader.get_total_size();
479481
}
@@ -694,37 +696,115 @@ const std::string BPEProcessor3::IdToPiece(int id) const
694696
}
695697
}
696698

697-
static std::string search_first_special_token(std::string &input, const _vocab &vocab, int &sp_tok_id)
699+
NearestKeywordSearcher::Node *NearestKeywordSearcher::make_tree(std::vector<Item> &items, char ch, int value)
698700
{
699-
sp_tok_id = -1;
700-
auto nearest_match = std::string::npos;
701-
for (auto & st: vocab.special_tokens_cache)
702-
{
703-
const auto & special_id = st.first;
704-
const auto & special_token = st.second;
701+
Node * r = new Node();
702+
r->ch = ch;
703+
r->value = items.size() < 1 ? value : -1;
705704

706-
auto match = input.find(special_token, 0);
705+
while (true)
706+
{
707+
bool flag = false;
708+
char tag = 0;
709+
int v = -1;
707710

708-
if (match < nearest_match)
711+
std::vector<Item> sub;
712+
for (int i = (int)items.size() - 1; i >= 0; i--)
709713
{
710-
nearest_match = match;
711-
sp_tok_id = special_id;
714+
if (items[i].s.size() < 1) continue;
715+
716+
if (!flag)
717+
{
718+
flag = true;
719+
tag = items[i].s[0];
720+
v = items[i].value;
721+
}
722+
else
723+
{
724+
if (items[i].s[0] != tag) continue;
725+
}
726+
727+
if (items[i].s.size() > 1)
728+
sub.emplace_back(items[i].s.substr(1), items[i].value);
729+
730+
// mark as visited
731+
items[i].s = "";
712732
}
733+
734+
if (!flag) break;
735+
736+
Node *child = make_tree(sub, tag, v);
737+
r->child.emplace_back(std::unique_ptr<Node>(child));
713738
}
714739

715-
if (sp_tok_id >= 0)
740+
std::sort(r->child.begin(), r->child.end(), [](auto &p1, auto &p2) { return p1->ch <= p2->ch; });
741+
742+
return r;
743+
}
744+
745+
void NearestKeywordSearcher::rebuild(const std::unordered_map<int, std::string> keywords)
746+
{
747+
root.reset(nullptr);
748+
749+
std::vector<Item> sub;
750+
751+
for (auto & st: keywords)
716752
{
717-
const auto & special_token = vocab.special_tokens_cache.at(sp_tok_id);
718-
std::string r = input.substr(0, nearest_match);
719-
input = input.substr(nearest_match + special_token.size());
720-
return r;
753+
sub.emplace_back(st.second, st.first);
721754
}
722-
else
755+
root.reset(make_tree(sub, 0, -1));
756+
}
757+
758+
int NearestKeywordSearcher::match(const std::string &input, int index, Node *node, int &level) const
759+
{
760+
if (node->child.size() < 1) return node->value;
761+
if (index >= (int)input.size()) return -1;
762+
const char ch = input[index];
763+
764+
int low = 0;
765+
int high = (int)node->child.size() - 1;
766+
while (high >= low)
767+
{
768+
// assuming no overflow
769+
int middle = (high + low) / 2;
770+
Node *n = node->child[middle].get();
771+
if (n->ch < ch)
772+
{
773+
low = middle + 1;
774+
}
775+
else if (ch < n->ch)
776+
{
777+
high = middle - 1;
778+
}
779+
else
780+
{
781+
level++;
782+
return match(input, index + 1, n, level);
783+
}
784+
}
785+
786+
return -1;
787+
}
788+
789+
std::string NearestKeywordSearcher::search(std::string &input, int &kw_id) const
790+
{
791+
int index = 0;
792+
while (index < (int)input.size())
723793
{
724-
std::string r(input);
725-
input = "";
726-
return r;
794+
int len = 0;
795+
kw_id = match(input, index, root.get(), len);
796+
if (kw_id >= 0)
797+
{
798+
std::string r = input.substr(0, index);
799+
input = input.substr(index + len);
800+
return r;
801+
}
802+
index++;
727803
}
804+
805+
std::string r(input);
806+
input = "";
807+
return r;
728808
}
729809

730810
int BPEProcessor2::DoEncode(const std::string &input,
@@ -734,7 +814,7 @@ int BPEProcessor2::DoEncode(const std::string &input,
734814
int sp_tok_id = -1;
735815
while (text.size() > 0)
736816
{
737-
auto leading = search_first_special_token(text, vocab_, sp_tok_id);
817+
auto leading = searcher.search(text, sp_tok_id);
738818
DoEncode2(leading, ids);
739819
if (sp_tok_id < 0) break;
740820
ids->push_back(sp_tok_id);

src/tokenizer.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,34 @@ class BPEProcessor1: public Processor
193193
std::vector<int> *ids) const override;
194194
};
195195

196+
class NearestKeywordSearcher
197+
{
198+
public:
199+
void rebuild(const std::unordered_map<int, std::string> keywords);
200+
201+
std::string search(std::string &input, int &kw_id) const;
202+
203+
protected:
204+
205+
struct Item
206+
{
207+
std::string s;
208+
int value;
209+
};
210+
struct Node
211+
{
212+
char ch;
213+
int value;
214+
std::vector<std::unique_ptr<Node>> child;
215+
};
216+
217+
Node *make_tree(std::vector<Item> &items, char ch, int value);
218+
219+
int match(const std::string &input, int index, Node *node, int &level) const;
220+
221+
std::unique_ptr<Node> root;
222+
};
223+
196224
class BPEProcessor2: public Processor
197225
{
198226
public:
@@ -212,6 +240,7 @@ class BPEProcessor2: public Processor
212240
std::vector<int> *ids) const;
213241

214242
std::vector<std::string> regex_exprs;
243+
NearestKeywordSearcher searcher;
215244
};
216245

217246
class BPEProcessor3: public BPEProcessor2

0 commit comments

Comments
 (0)