Skip to content

Commit b069b10

Browse files
pwilkinngxsonsayap
authored
vocab: fix Gemma4 tokenizer (ggml-org#21343)
* seems to work * fix case with new line Co-authored-by: sayap <sokann@gmail.com> * gemma 4: fix pre tok regex --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: sayap <sokann@gmail.com>
1 parent 0c58ba3 commit b069b10

5 files changed

Lines changed: 69 additions & 9 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7464,9 +7464,6 @@ def set_vocab(self):
74647464

74657465
assert len(tokens) == vocab.vocab_size
74667466

7467-
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
7468-
# but I don't have time to dive into them right now;
7469-
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
74707467
self.gguf_writer.add_tokenizer_model("gemma4")
74717468
self.gguf_writer.add_token_list(tokens)
74727469
self.gguf_writer.add_token_scores(scores)

src/llama-vocab.cpp

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer {
493493
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
494494
};
495495
break;
496+
case LLAMA_VOCAB_PRE_TYPE_GEMMA4:
497+
// Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the
498+
// normalizer, then BPE merges run on the whole text without
499+
// word-level pre-splitting. We only need to split on newlines
500+
// since BPE merge lookup asserts no newlines in tokens.
501+
regex_exprs = {
502+
"[^\\n]+|[\\n]+",
503+
};
504+
byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
505+
break;
496506
default:
497507
// default regex for BPE tokenization pre-processing
498508
regex_exprs = {
@@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
506516
}
507517

508518
std::vector<std::string> regex_exprs;
519+
bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8)
509520
};
510521

511522
struct llm_tokenizer_bpe_session {
@@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session {
550561

551562
void tokenize(const std::string & text, std::vector<llama_token> & output) {
552563
int final_prev_index = -1;
553-
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
564+
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode);
554565

555566
symbols_final.clear();
567+
auto tok_pre = vocab.get_pre_type();
556568

557569
for (const auto & word : word_collection) {
558570
work_queue = llm_bigram_bpe::queue();
@@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session {
565577
if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
566578
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
567579
offset = word.size();
580+
} else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) {
581+
// fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343
582+
auto tok = vocab.text_to_token(word);
583+
if (tok != LLAMA_TOKEN_NULL) {
584+
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
585+
offset = word.size();
586+
}
568587
}
569588

570589
while (offset < word.size()) {
@@ -1864,7 +1883,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
18641883
special_pad_id = 3; // <|plamo:pad|>
18651884
special_mask_id = LLAMA_TOKEN_NULL;
18661885
} else if (tokenizer_model == "gemma4") {
1867-
type = LLAMA_VOCAB_TYPE_SPM;
1886+
type = LLAMA_VOCAB_TYPE_BPE;
1887+
1888+
// read bpe merges and populate bpe ranks
1889+
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
1890+
if (merges_keyidx == -1) {
1891+
throw std::runtime_error("cannot find tokenizer merges in model file\n");
1892+
}
1893+
{
1894+
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
1895+
for (int i = 0; i < n_merges; i++) {
1896+
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
1897+
1898+
std::string first;
1899+
std::string second;
1900+
1901+
const size_t pos = word.find(' ', 1);
1902+
1903+
if (pos != std::string::npos) {
1904+
first = word.substr(0, pos);
1905+
second = word.substr(pos + 1);
1906+
}
1907+
1908+
bpe_ranks.emplace(std::make_pair(first, second), i);
1909+
}
1910+
}
18681911

18691912
// default special tokens (to be read from GGUF)
18701913
special_bos_id = LLAMA_TOKEN_NULL;
@@ -1874,14 +1917,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
18741917
special_pad_id = LLAMA_TOKEN_NULL;
18751918
special_mask_id = LLAMA_TOKEN_NULL;
18761919

1877-
tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
1920+
tokenizer_pre = "gemma4";
18781921
} else {
18791922
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
18801923
}
18811924

18821925
// for now, only BPE models have pre-tokenizers
18831926
if (type == LLAMA_VOCAB_TYPE_BPE) {
18841927
add_space_prefix = false;
1928+
escape_whitespaces = false;
18851929
clean_spaces = true;
18861930
if (tokenizer_pre.empty()) {
18871931
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
@@ -1948,6 +1992,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
19481992
} else if (
19491993
tokenizer_pre == "jais-2") {
19501994
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
1995+
} else if (
1996+
tokenizer_pre == "gemma4") {
1997+
pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
1998+
escape_whitespaces = true;
19511999
} else if (
19522000
tokenizer_pre == "jina-v1-en" ||
19532001
tokenizer_pre == "jina-v2-code" ||
@@ -3045,6 +3093,10 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
30453093
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
30463094
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
30473095

3096+
if (escape_whitespaces) {
3097+
llama_escape_whitespace(text);
3098+
}
3099+
30483100
#ifdef PRETOKENIZERDEBUG
30493101
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
30503102
#endif
@@ -3224,6 +3276,12 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
32243276
return _try_copy(token_text.data(), token_text.size());
32253277
}
32263278
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
3279+
if (escape_whitespaces) {
3280+
// SPM-style BPE: tokens contain ▁ for spaces
3281+
std::string result = token_text;
3282+
llama_unescape_whitespace(result);
3283+
return _try_copy(result.data(), result.size());
3284+
}
32273285
std::string result = llama_decode_text(token_text);
32283286
return _try_copy(result.data(), result.size());
32293287
}

src/llama-vocab.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
5858
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
5959
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
6060
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
61+
LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50,
6162
};
6263

6364
struct LLM_KV;

src/unicode.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ bool unicode_cpt_is_han(uint32_t cpt) {
912912
return false;
913913
}
914914

915-
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
915+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) {
916916
// unicode categories
917917
static const std::map<std::string, int> k_ucat_enum = {
918918
{ "\\p{N}", unicode_cpt_flags::NUMBER },
@@ -1099,5 +1099,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
10991099
start += offset;
11001100
}
11011101

1102-
return unicode_byte_encoding_process(bpe_words);
1102+
if (byte_encode) {
1103+
return unicode_byte_encoding_process(bpe_words);
1104+
}
1105+
1106+
return bpe_words;
11031107
}

src/unicode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt);
108108

109109
bool unicode_cpt_is_han(uint32_t cpt);
110110

111-
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
111+
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true);

0 commit comments

Comments
 (0)