From e00b1238c24a1cc26a3dda2e43528f3363ac16f0 Mon Sep 17 00:00:00 2001 From: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:24:19 -0800 Subject: [PATCH] Add phi-4 regex support (#877) * add phi4 regex support * minor changes * resolve comments * oob check * code clean up * fix the pattern match and format the code. * fix the unit test. --------- Co-authored-by: Sayan Shaw Co-authored-by: Wenbing Li --- operators/tokenizer/bpe_utils.hpp | 140 ++++++++++++++++++++++++++++- test/pp_api_test/test_tokenizer.cc | 20 +++++ test/test_pp_api.py | 2 +- 3 files changed, 158 insertions(+), 4 deletions(-) diff --git a/operators/tokenizer/bpe_utils.hpp b/operators/tokenizer/bpe_utils.hpp index 4003973f..c2996a69 100644 --- a/operators/tokenizer/bpe_utils.hpp +++ b/operators/tokenizer/bpe_utils.hpp @@ -56,8 +56,8 @@ class SpecialTokenMap { auto search_it = std::search(it, str.first.end(), std::boyer_moore_searcher(st.str.begin(), st.str.end())); #endif if (search_it == str.first.end()) { - new_split_res.emplace_back(std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos), - kInvalidTokenId); + new_split_res.emplace_back( + std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos), kInvalidTokenId); break; } @@ -359,6 +359,117 @@ class PreTokenizerWithRegEx { return {}; } + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)? + std::u32string_view Match_PHI4_Pattern_1() { + size_t i = 0; + + // [^\r\n\p{L}\p{N}]? + if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) { + i++; + } + + size_t j = i; + // [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]* + const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt | + ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | + ufal::unilib::unicode::M; + if (IsCategory(m_text[i], categories1)) { + for (; j < m_text.size(); ++j) { + if (!IsCategory(m_text[j], categories1)) break; + } + } + + // [\p{Ll}\p{Lm}\p{Lo}\p{M}]+ + const ufal::unilib::unicode::category_t categories2 = + ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M; + + if (IsCategory(m_text[j], categories2)) { + for (; j < m_text.size(); ++j) { + if (!IsCategory(m_text[j], categories2)) break; + } + } else if (j > i && j > 0 && IsCategory(m_text[j - 1], categories2)) { + for (; j < m_text.size(); ++j) { + if (!IsCategory(m_text[j], categories2)) break; + } + } else { + return {}; + } + + i = j; + // (?i:'s|'t|'re|'ve|'m|'ll|'d)? + if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) { + if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') || + (m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) { + i += 2; + } else if ((i + 2) < m_text.size()) { + if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && + ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) || + (((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && + ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) || + (((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && + ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) { + i += 3; + } + } + } + + std::u32string_view res = m_text.substr(0, i); + m_text = m_text.substr(i); + return res; + } + + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)? + std::u32string_view Match_PHI4_Pattern_2() { + size_t i = 0; + + // [^\r\n\p{L}\p{N}]? + if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) { + i++; + } + + // [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+ + const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt | + ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | + ufal::unilib::unicode::M; + if (IsCategory(m_text[i], categories1)) { + for (; i < m_text.size(); ++i) { + if (!IsCategory(m_text[i], categories1)) break; + } + } else { + return {}; + } + + // [\p{Ll}\p{Lm}\p{Lo}\p{M}]* + const ufal::unilib::unicode::category_t categories2 = + ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M; + if (IsCategory(m_text[i], categories2)) { + for (; i < m_text.size(); ++i) { + if (!IsCategory(m_text[i], categories2)) break; + } + } + + // (?i:'s|'t|'re|'ve|'m|'ll|'d)? + if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) { + if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') || + (m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) { + i += 2; + } else if ((i + 2) < m_text.size()) { + if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && + ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) || + (((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && + ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) || + (((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && + ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) { + i += 3; + } + } + } + + std::u32string_view res = m_text.substr(0, i); + m_text = m_text.substr(i); + return res; + } + // "(\p{N})" std::u32string_view Match_General_Pattern_1() { if (IsN(m_text[0])) { @@ -376,6 +487,10 @@ class PreTokenizerWithRegEx { auto patterns = std::vector>{ {R"((?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]))", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1}, + {R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", + &PreTokenizerWithRegEx::Match_PHI4_Pattern_1}, + {R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", + &PreTokenizerWithRegEx::Match_PHI4_Pattern_2}, {R"((?i:'s|'t|'re|'ve|'m|'ll|'d))", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1}, {R"('s|'t|'re|'ve|'m|'ll|'d)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_1}, {R"([^\r\n\p{L}\p{N}]?\p{L}+)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_2}, @@ -387,6 +502,7 @@ class PreTokenizerWithRegEx { {R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4}, {R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1}, {R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2}, + {R"(?[^\s\p{L}\p{N}]+[\r\n/]*)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_4}, {R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1}, }; @@ -416,7 +532,7 @@ class PreTokenizerWithRegEx { } else { if (pattern_size < regex_compound.size()) { assert(regex_compound[pattern_size] == '|'); - pattern_size++; // let the pattern include the '|' + pattern_size++; // let the pattern include the '|' } } regex_compound = regex_prefix + regex_compound.substr(pos + pattern_size); @@ -501,11 +617,29 @@ class PreTokenizerWithRegEx { public: static bool IsRN(char32_t ch) { return ch == U'\r' || ch == U'\n'; } + static bool IsCategory(char32_t ch, ufal::unilib::unicode::category_t category) { + auto ch_category = ufal::unilib::unicode::category(ch); + return (ch_category & category) != 0; + } + static bool IsL(char32_t ch) { auto category = ufal::unilib::unicode::category(ch); return (category & ufal::unilib::unicode::L) != 0; } + static bool IsLuLtLmLoM(char32_t ch) { + auto category = ufal::unilib::unicode::category(ch); + return ((category & ufal::unilib::unicode::Lu) != 0 || (category & ufal::unilib::unicode::Lt) != 0 || + (category & ufal::unilib::unicode::Lm) != 0 || (category & ufal::unilib::unicode::Lo) != 0 || + (category & ufal::unilib::unicode::M) != 0); + } + + static bool IsLlLmLoM(char32_t ch) { + auto category = ufal::unilib::unicode::category(ch); + return ((category & ufal::unilib::unicode::Ll) != 0 || (category & ufal::unilib::unicode::Lm) != 0 || + (category & ufal::unilib::unicode::Lo) != 0 || (category & ufal::unilib::unicode::M) != 0); + } + static bool IsN(char32_t ch) { auto category = ufal::unilib::unicode::category(ch); return (category & ufal::unilib::unicode::N) != 0; diff --git a/test/pp_api_test/test_tokenizer.cc b/test/pp_api_test/test_tokenizer.cc index 17060442..4beb39b5 100644 --- a/test/pp_api_test/test_tokenizer.cc +++ b/test/pp_api_test/test_tokenizer.cc @@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) { OrtxDispose(&tokenizer); } +TEST(OrtxTokenizerTest, RegexTest) { + std::u32string str = U"You'll enjoy the concert."; + auto reg_splitter = std::make_unique(); + + std::vector res; + std::vector out_tokens = {U"You'll", U" enjoy", U" the", U" concert"}; + + int64_t max_length = out_tokens.size(); + reg_splitter->Set(str.c_str()); + auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)"); + assert(status.IsOk()); + + while (static_cast(res.size()) < max_length) { + std::u32string_view tok = reg_splitter->GetNextToken(); + res.push_back(ustring(tok)); + } + + EXPECT_EQ(res, out_tokens); +} + TEST(OrtxTokenizerTest, ClipTokenizer) { auto tokenizer = std::make_unique(); auto status = tokenizer->Load("data/tokenizer/clip"); diff --git a/test/test_pp_api.py b/test/test_pp_api.py index 7abca802..4b45cd44 100644 --- a/test/test_pp_api.py +++ b/test/test_pp_api.py @@ -120,7 +120,7 @@ def test_llama3_2_image_processing(self): a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png") # test sentence for tokenizer - tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61" + tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61. You'll enjoy the concert." def test_OLMa_tokenizer(self): test_sentence = [self.tokenizer_test_sentence + " |||IP_ADDRESS|||"]