Skip to content

Commit e00b123

Browse files
Authored by sayanshaw24 (Sayan Shaw) and wenbingl
Add phi-4 regex support (#877)
* add phi4 regex support
* minor changes
* resolve comments
* oob check
* code clean up
* fix the pattern match and format the code.
* fix the unit test.

---------

Co-authored-by: Sayan Shaw <[email protected]>
Co-authored-by: Wenbing Li <[email protected]>
1 parent f8f3ae9 commit e00b123

File tree

3 files changed

+158
-4
lines changed

3 files changed

+158
-4
lines changed

operators/tokenizer/bpe_utils.hpp

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class SpecialTokenMap {
5656
auto search_it = std::search(it, str.first.end(), std::boyer_moore_searcher(st.str.begin(), st.str.end()));
5757
#endif
5858
if (search_it == str.first.end()) {
59-
new_split_res.emplace_back(std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos),
60-
kInvalidTokenId);
59+
new_split_res.emplace_back(
60+
std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos), kInvalidTokenId);
6161
break;
6262
}
6363

@@ -359,6 +359,117 @@ class PreTokenizerWithRegEx {
359359
return {};
360360
}
361361

362+
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?
363+
std::u32string_view Match_PHI4_Pattern_1() {
364+
size_t i = 0;
365+
366+
// [^\r\n\p{L}\p{N}]?
367+
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
368+
i++;
369+
}
370+
371+
size_t j = i;
372+
// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*
373+
const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
374+
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
375+
ufal::unilib::unicode::M;
376+
if (IsCategory(m_text[i], categories1)) {
377+
for (; j < m_text.size(); ++j) {
378+
if (!IsCategory(m_text[j], categories1)) break;
379+
}
380+
}
381+
382+
// [\p{Ll}\p{Lm}\p{Lo}\p{M}]+
383+
const ufal::unilib::unicode::category_t categories2 =
384+
ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
385+
386+
if (IsCategory(m_text[j], categories2)) {
387+
for (; j < m_text.size(); ++j) {
388+
if (!IsCategory(m_text[j], categories2)) break;
389+
}
390+
} else if (j > i && j > 0 && IsCategory(m_text[j - 1], categories2)) {
391+
for (; j < m_text.size(); ++j) {
392+
if (!IsCategory(m_text[j], categories2)) break;
393+
}
394+
} else {
395+
return {};
396+
}
397+
398+
i = j;
399+
// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
400+
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
401+
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
402+
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
403+
i += 2;
404+
} else if ((i + 2) < m_text.size()) {
405+
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) &&
406+
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
407+
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) &&
408+
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
409+
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) &&
410+
((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
411+
i += 3;
412+
}
413+
}
414+
}
415+
416+
std::u32string_view res = m_text.substr(0, i);
417+
m_text = m_text.substr(i);
418+
return res;
419+
}
420+
421+
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?
422+
std::u32string_view Match_PHI4_Pattern_2() {
423+
size_t i = 0;
424+
425+
// [^\r\n\p{L}\p{N}]?
426+
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
427+
i++;
428+
}
429+
430+
// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+
431+
const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
432+
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
433+
ufal::unilib::unicode::M;
434+
if (IsCategory(m_text[i], categories1)) {
435+
for (; i < m_text.size(); ++i) {
436+
if (!IsCategory(m_text[i], categories1)) break;
437+
}
438+
} else {
439+
return {};
440+
}
441+
442+
// [\p{Ll}\p{Lm}\p{Lo}\p{M}]*
443+
const ufal::unilib::unicode::category_t categories2 =
444+
ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
445+
if (IsCategory(m_text[i], categories2)) {
446+
for (; i < m_text.size(); ++i) {
447+
if (!IsCategory(m_text[i], categories2)) break;
448+
}
449+
}
450+
451+
// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
452+
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
453+
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
454+
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
455+
i += 2;
456+
} else if ((i + 2) < m_text.size()) {
457+
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) &&
458+
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
459+
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) &&
460+
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
461+
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) &&
462+
((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
463+
i += 3;
464+
}
465+
}
466+
}
467+
468+
std::u32string_view res = m_text.substr(0, i);
469+
m_text = m_text.substr(i);
470+
return res;
471+
}
472+
362473
// "(\p{N})"
363474
std::u32string_view Match_General_Pattern_1() {
364475
if (IsN(m_text[0])) {
@@ -376,6 +487,10 @@ class PreTokenizerWithRegEx {
376487
auto patterns = std::vector<std::tuple<std::string_view, RegexMatchFunc>>{
377488
{R"((?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]))",
378489
&PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
490+
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
491+
&PreTokenizerWithRegEx::Match_PHI4_Pattern_1},
492+
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
493+
&PreTokenizerWithRegEx::Match_PHI4_Pattern_2},
379494
{R"((?i:'s|'t|'re|'ve|'m|'ll|'d))", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
380495
{R"('s|'t|'re|'ve|'m|'ll|'d)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_1},
381496
{R"([^\r\n\p{L}\p{N}]?\p{L}+)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_2},
@@ -387,6 +502,7 @@ class PreTokenizerWithRegEx {
387502
{R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4},
388503
{R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1},
389504
{R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2},
505+
{R"(?[^\s\p{L}\p{N}]+[\r\n/]*)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_4},
390506
{R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1},
391507
};
392508

@@ -416,7 +532,7 @@ class PreTokenizerWithRegEx {
416532
} else {
417533
if (pattern_size < regex_compound.size()) {
418534
assert(regex_compound[pattern_size] == '|');
419-
pattern_size++; // let the pattern include the '|'
535+
pattern_size++; // let the pattern include the '|'
420536
}
421537
}
422538
regex_compound = regex_prefix + regex_compound.substr(pos + pattern_size);
@@ -501,11 +617,29 @@ class PreTokenizerWithRegEx {
501617
public:
502618
static bool IsRN(char32_t ch) { return ch == U'\r' || ch == U'\n'; }
503619

620+
static bool IsCategory(char32_t ch, ufal::unilib::unicode::category_t category) {
621+
auto ch_category = ufal::unilib::unicode::category(ch);
622+
return (ch_category & category) != 0;
623+
}
624+
504625
static bool IsL(char32_t ch) {
505626
auto category = ufal::unilib::unicode::category(ch);
506627
return (category & ufal::unilib::unicode::L) != 0;
507628
}
508629

630+
static bool IsLuLtLmLoM(char32_t ch) {
631+
auto category = ufal::unilib::unicode::category(ch);
632+
return ((category & ufal::unilib::unicode::Lu) != 0 || (category & ufal::unilib::unicode::Lt) != 0 ||
633+
(category & ufal::unilib::unicode::Lm) != 0 || (category & ufal::unilib::unicode::Lo) != 0 ||
634+
(category & ufal::unilib::unicode::M) != 0);
635+
}
636+
637+
static bool IsLlLmLoM(char32_t ch) {
638+
auto category = ufal::unilib::unicode::category(ch);
639+
return ((category & ufal::unilib::unicode::Ll) != 0 || (category & ufal::unilib::unicode::Lm) != 0 ||
640+
(category & ufal::unilib::unicode::Lo) != 0 || (category & ufal::unilib::unicode::M) != 0);
641+
}
642+
509643
static bool IsN(char32_t ch) {
510644
auto category = ufal::unilib::unicode::category(ch);
511645
return (category & ufal::unilib::unicode::N) != 0;

test/pp_api_test/test_tokenizer.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) {
6565
OrtxDispose(&tokenizer);
6666
}
6767

68+
// Verifies the hand-written phi-4 regex matcher splits a sentence into the
// expected word tokens (including the "'ll" contraction suffix).
TEST(OrtxTokenizerTest, RegexTest) {
  std::u32string str = U"You'll enjoy the concert.";
  auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

  std::vector<std::u32string> res;
  std::vector<std::u32string> out_tokens = {U"You'll", U" enjoy", U" the", U" concert"};

  int64_t max_length = out_tokens.size();
  reg_splitter->Set(str.c_str());
  auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
  // assert() is compiled out in NDEBUG (release) builds; use the gtest macro
  // so a Compile failure always fails the test instead of being ignored.
  ASSERT_TRUE(status.IsOk());

  while (static_cast<int64_t>(res.size()) < max_length) {
    std::u32string_view tok = reg_splitter->GetNextToken();
    res.push_back(ustring(tok));
  }

  EXPECT_EQ(res, out_tokens);
}
87+
6888
TEST(OrtxTokenizerTest, ClipTokenizer) {
6989
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
7090
auto status = tokenizer->Load("data/tokenizer/clip");

test/test_pp_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_llama3_2_image_processing(self):
120120
a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")
121121

122122
# test sentence for tokenizer
123-
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61"
123+
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61. You'll enjoy the concert."
124124

125125
def test_OLMa_tokenizer(self):
126126
test_sentence = [self.tokenizer_test_sentence + " |||IP_ADDRESS|||"]

0 commit comments

Comments
 (0)