Skip to content

Commit c8002f0

Browse files
author
Sayan Shaw
committed
add phi4 regex support
1 parent e8bf5a9 commit c8002f0

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

operators/tokenizer/bpe_tokenizer_model.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,8 @@ class BpeModel {
417417

418418
if (model_name == "Llama") {
419419
return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
420+
} else if (model_name == "Phi") {
421+
return bpe::PreTokenizerWithRegEx::PHI4_REGEX_PATTERN;
420422
}
421423

422424
// by default, use the GPT2 pretokenizer regex

operators/tokenizer/bpe_utils.hpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ class PreTokenizerWithRegEx {
9797
R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)";
9898
static constexpr const char LLAMA_REGEX_PATTERN[] =
9999
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+)";
100+
static constexpr const char PHI4_REGEX_PATTERN[] =
101+
R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|?[^\s\p{L}\p{N}]+[\r\n]*)";
100102

101103
PreTokenizerWithRegEx() = default;
102104

@@ -359,6 +361,114 @@ class PreTokenizerWithRegEx {
359361
return {};
360362
}
361363

364+
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?
365+
std::u32string_view Match_PHI4_Pattern_1() {
366+
size_t i = 0;
367+
368+
// [^\r\n\p{L}\p{N}]?
369+
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
370+
i++;
371+
}
372+
373+
// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*
374+
std::vector<ufal::unilib::unicode::category_t> categories1 = {ufal::unilib::unicode::Lu,
375+
ufal::unilib::unicode::Lt,
376+
ufal::unilib::unicode::Lm,
377+
ufal::unilib::unicode::Lo,
378+
ufal::unilib::unicode::M};
379+
while (std::find(categories1.begin(), categories1.end(), ufal::unilib::unicode::category(m_text[i])) != categories1.end()){
380+
i++;
381+
}
382+
383+
// [\p{Ll}\p{Lm}\p{Lo}\p{M}]+
384+
size_t j = i;
385+
std::vector<ufal::unilib::unicode::category_t> categories2 = {ufal::unilib::unicode::Ll,
386+
ufal::unilib::unicode::Lm,
387+
ufal::unilib::unicode::Lo,
388+
ufal::unilib::unicode::M};
389+
while (std::find(categories2.begin(), categories2.end(), ufal::unilib::unicode::category(m_text[i])) != categories2.end()){
390+
i++;
391+
}
392+
if (i == j){
393+
// No case match, return as this is a '+' category case (one or more occurrences must be found)
394+
std::u32string_view res = m_text.substr(0, i);
395+
m_text = m_text.substr(i);
396+
return res;
397+
}
398+
399+
// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
400+
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
401+
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
402+
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
403+
i += 2;
404+
} else if ((i + 2) < m_text.size()) {
405+
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
406+
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
407+
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
408+
i += 3;
409+
}
410+
}
411+
}
412+
413+
std::u32string_view res = m_text.substr(0, i);
414+
m_text = m_text.substr(i);
415+
return res;
416+
}
417+
418+
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?
419+
std::u32string_view Match_PHI4_Pattern_2() {
420+
size_t i = 0;
421+
422+
// [^\r\n\p{L}\p{N}]?
423+
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
424+
i++;
425+
}
426+
427+
// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+
428+
size_t j = i;
429+
std::vector<ufal::unilib::unicode::category_t> categories1 = {ufal::unilib::unicode::Lu,
430+
ufal::unilib::unicode::Lt,
431+
ufal::unilib::unicode::Lm,
432+
ufal::unilib::unicode::Lo,
433+
ufal::unilib::unicode::M};
434+
while (std::find(categories1.begin(), categories1.end(), ufal::unilib::unicode::category(m_text[i])) != categories1.end()){
435+
i++;
436+
}
437+
if (i == j){
438+
// No case match, return as this is a '+' category case (one or more occurrences must be found)
439+
std::u32string_view res = m_text.substr(0, i);
440+
m_text = m_text.substr(i);
441+
return res;
442+
}
443+
444+
// [\p{Ll}\p{Lm}\p{Lo}\p{M}]*
445+
std::vector<ufal::unilib::unicode::category_t> categories2 = {ufal::unilib::unicode::Ll,
446+
ufal::unilib::unicode::Lm,
447+
ufal::unilib::unicode::Lo,
448+
ufal::unilib::unicode::M};
449+
while (std::find(categories2.begin(), categories2.end(), ufal::unilib::unicode::category(m_text[i])) != categories2.end()){
450+
i++;
451+
}
452+
453+
// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
454+
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
455+
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
456+
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
457+
i += 2;
458+
} else if ((i + 2) < m_text.size()) {
459+
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
460+
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
461+
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
462+
i += 3;
463+
}
464+
}
465+
}
466+
467+
std::u32string_view res = m_text.substr(0, i);
468+
m_text = m_text.substr(i);
469+
return res;
470+
}
471+
362472
// "(\p{N})"
363473
std::u32string_view Match_General_Pattern_1() {
364474
if (IsN(m_text[0])) {
@@ -387,6 +497,8 @@ class PreTokenizerWithRegEx {
387497
{R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4},
388498
{R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1},
389499
{R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2},
500+
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", &PreTokenizerWithRegEx::Match_PHI4_Pattern_1},
501+
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", &PreTokenizerWithRegEx::Match_PHI4_Pattern_2},
390502
{R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1},
391503
};
392504

test/pp_api_test/test_tokenizer.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) {
6565
OrtxDispose(&tokenizer);
6666
}
6767

68+
TEST(OrtxTokenizerTest, RegexTest) {
69+
std::u32string str = U"You'll enjoy the concert.";
70+
auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
71+
72+
std::vector<std::u32string> res;
73+
std::vector<std::u32string> out_tokens = {U"You'll"};
74+
75+
int64_t max_length = out_tokens.size();
76+
reg_splitter->Set(str.c_str());
77+
auto regex_expr = reg_splitter->PHI4_REGEX_PATTERN;
78+
auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
79+
assert(status.IsOk());
80+
81+
while (static_cast<int64_t>(res.size()) < max_length) {
82+
std::u32string_view tok = reg_splitter->GetNextToken();
83+
res.push_back(ustring(tok));
84+
}
85+
EXPECT_EQ(res, out_tokens);
86+
}
87+
6888
TEST(OrtxTokenizerTest, ClipTokenizer) {
6989
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
7090
auto status = tokenizer->Load("data/tokenizer/clip");

0 commit comments

Comments
 (0)