Add phi-4 regex support (#877)
* add phi4 regex support

* minor changes

* resolve comments

* oob check

* code clean up

* fix the pattern match and format the code.

* fix the unit test.

---------

Co-authored-by: Sayan Shaw <[email protected]>
Co-authored-by: Wenbing Li <[email protected]>
3 people authored Jan 18, 2025
1 parent f8f3ae9 commit e00b123
Showing 3 changed files with 158 additions and 4 deletions.
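
This commit adds two hand-written matchers (Match_PHI4_Pattern_1 and Match_PHI4_Pattern_2) for the word-level alternatives of the phi-4 pre-tokenization regex, registers their pattern strings in the matcher table, and covers them with a new C++ RegexTest plus an extended Python test sentence. The matchers classify code points with unilib's Unicode category bitmasks rather than delegating to a regex engine. As a rough orientation, a minimal sketch of that style of category test (standalone illustration, not part of the diff; the unilib include path is an assumption):

#include <cassert>
#include "unicode.h"  // ufal::unilib Unicode tables; the exact include path in the repo may differ

int main() {
  using ufal::unilib::unicode;
  // Category masks mirroring the ones built inside the new matchers below.
  const unicode::category_t upper_like =
      unicode::Lu | unicode::Lt | unicode::Lm | unicode::Lo | unicode::M;
  const unicode::category_t lower_like =
      unicode::Ll | unicode::Lm | unicode::Lo | unicode::M;
  assert((unicode::category(U'Y') & upper_like) != 0);                  // 'Y' is Lu
  assert((unicode::category(U'o') & lower_like) != 0);                  // 'o' is Ll
  assert((unicode::category(U'7') & (upper_like | lower_like)) == 0);   // '7' is Nd, matched by neither class
  return 0;
}
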
140 changes: 137 additions & 3 deletions operators/tokenizer/bpe_utils.hpp
@@ -56,8 +56,8 @@ class SpecialTokenMap {
auto search_it = std::search(it, str.first.end(), std::boyer_moore_searcher(st.str.begin(), st.str.end()));
#endif
if (search_it == str.first.end()) {
new_split_res.emplace_back(std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos),
kInvalidTokenId);
new_split_res.emplace_back(
std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos), kInvalidTokenId);
break;
}

@@ -359,6 +359,117 @@ class PreTokenizerWithRegEx {
return {};
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?
std::u32string_view Match_PHI4_Pattern_1() {
size_t i = 0;

// [^\r\n\p{L}\p{N}]?
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
i++;
}

size_t j = i;
// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*
const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
ufal::unilib::unicode::M;
if (IsCategory(m_text[i], categories1)) {
for (; j < m_text.size(); ++j) {
if (!IsCategory(m_text[j], categories1)) break;
}
}

// [\p{Ll}\p{Lm}\p{Lo}\p{M}]+
const ufal::unilib::unicode::category_t categories2 =
ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;

if (IsCategory(m_text[j], categories2)) {
for (; j < m_text.size(); ++j) {
if (!IsCategory(m_text[j], categories2)) break;
}
} else if (j > i && j > 0 && IsCategory(m_text[j - 1], categories2)) {
for (; j < m_text.size(); ++j) {
if (!IsCategory(m_text[j], categories2)) break;
}
} else {
return {};
}

i = j;
// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
i += 2;
} else if ((i + 2) < m_text.size()) {
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) &&
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) &&
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) &&
((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
i += 3;
}
}
}

std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);
return res;
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?
std::u32string_view Match_PHI4_Pattern_2() {
size_t i = 0;

// [^\r\n\p{L}\p{N}]?
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
i++;
}

// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+
const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
ufal::unilib::unicode::M;
if (IsCategory(m_text[i], categories1)) {
for (; i < m_text.size(); ++i) {
if (!IsCategory(m_text[i], categories1)) break;
}
} else {
return {};
}

// [\p{Ll}\p{Lm}\p{Lo}\p{M}]*
const ufal::unilib::unicode::category_t categories2 =
ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
if (IsCategory(m_text[i], categories2)) {
for (; i < m_text.size(); ++i) {
if (!IsCategory(m_text[i], categories2)) break;
}
}

// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
if ((m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
i += 2;
} else if ((i + 2) < m_text.size()) {
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) &&
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) &&
((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) &&
((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
i += 3;
}
}
}

std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);
return res;
}

// "(\p{N})"
std::u32string_view Match_General_Pattern_1() {
if (IsN(m_text[0])) {
@@ -376,6 +487,10 @@ class PreTokenizerWithRegEx {
auto patterns = std::vector<std::tuple<std::string_view, RegexMatchFunc>>{
{R"((?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]))",
&PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
&PreTokenizerWithRegEx::Match_PHI4_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
&PreTokenizerWithRegEx::Match_PHI4_Pattern_2},
{R"((?i:'s|'t|'re|'ve|'m|'ll|'d))", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
{R"('s|'t|'re|'ve|'m|'ll|'d)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?\p{L}+)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_2},
@@ -387,6 +502,7 @@ class PreTokenizerWithRegEx {
{R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4},
{R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1},
{R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2},
{R"(?[^\s\p{L}\p{N}]+[\r\n/]*)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_4},
{R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1},
};

@@ -416,7 +532,7 @@ class PreTokenizerWithRegEx {
} else {
if (pattern_size < regex_compound.size()) {
assert(regex_compound[pattern_size] == '|');
pattern_size++; // let the pattern include the '|'
pattern_size++; // let the pattern include the '|'
}
}
regex_compound = regex_prefix + regex_compound.substr(pos + pattern_size);
@@ -501,11 +617,29 @@ class PreTokenizerWithRegEx {
public:
static bool IsRN(char32_t ch) { return ch == U'\r' || ch == U'\n'; }

static bool IsCategory(char32_t ch, ufal::unilib::unicode::category_t category) {
auto ch_category = ufal::unilib::unicode::category(ch);
return (ch_category & category) != 0;
}

static bool IsL(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return (category & ufal::unilib::unicode::L) != 0;
}

static bool IsLuLtLmLoM(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return ((category & ufal::unilib::unicode::Lu) != 0 || (category & ufal::unilib::unicode::Lt) != 0 ||
(category & ufal::unilib::unicode::Lm) != 0 || (category & ufal::unilib::unicode::Lo) != 0 ||
(category & ufal::unilib::unicode::M) != 0);
}

static bool IsLlLmLoM(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return ((category & ufal::unilib::unicode::Ll) != 0 || (category & ufal::unilib::unicode::Lm) != 0 ||
(category & ufal::unilib::unicode::Lo) != 0 || (category & ufal::unilib::unicode::M) != 0);
}

static bool IsN(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return (category & ufal::unilib::unicode::N) != 0;
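
Both new matchers end with the same hand-rolled check for the case-insensitive contraction suffix (?i:'s|'t|'re|'ve|'m|'ll|'d)?. For readers skimming the character comparisons above, a minimal standalone sketch of that logic (hypothetical helper, not part of this commit; ASCII-only lowering suffices because the suffix letters are ASCII):

#include <cassert>
#include <cstddef>
#include <string_view>

// Returns the number of code points consumed by a contraction suffix starting at i, or 0 if none.
static size_t MatchContractionSuffix(std::u32string_view text, size_t i) {
  if (i >= text.size() || text[i] != U'\'' || i + 1 >= text.size()) return 0;
  auto lower = [](char32_t c) { return (c >= U'A' && c <= U'Z') ? c + (U'a' - U'A') : c; };
  const char32_t c1 = lower(text[i + 1]);
  if (c1 == U's' || c1 == U't' || c1 == U'm' || c1 == U'd') return 2;  // 's 't 'm 'd
  if (i + 2 < text.size()) {
    const char32_t c2 = lower(text[i + 2]);
    if ((c1 == U'r' && c2 == U'e') || (c1 == U'v' && c2 == U'e') || (c1 == U'l' && c2 == U'l'))
      return 3;  // 're 've 'll
  }
  return 0;
}

int main() {
  assert(MatchContractionSuffix(U"'LL go", 0) == 3);  // case-insensitive 'll
  assert(MatchContractionSuffix(U"You'd", 3) == 2);   // 'd
  assert(MatchContractionSuffix(U"'x", 0) == 0);      // not a recognized suffix
  return 0;
}
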
20 changes: 20 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
@@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) {
OrtxDispose(&tokenizer);
}

TEST(OrtxTokenizerTest, RegexTest) {
std::u32string str = U"You'll enjoy the concert.";
auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::u32string> res;
std::vector<std::u32string> out_tokens = {U"You'll", U" enjoy", U" the", U" concert"};

int64_t max_length = out_tokens.size();
reg_splitter->Set(str.c_str());
auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
assert(status.IsOk());

while (static_cast<int64_t>(res.size()) < max_length) {
std::u32string_view tok = reg_splitter->GetNextToken();
res.push_back(ustring(tok));
}

EXPECT_EQ(res, out_tokens);
}

TEST(OrtxTokenizerTest, ClipTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/tokenizer/clip");
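
The new RegexTest above exercises only the first phi-4 alternative (lowercase-leading words). A hypothetical companion test for the second alternative, which requires at least one uppercase/titlecase letter before an optional lowercase run, could follow the same shape (sketch only, not part of this commit; the expected splits assume the behavior of Match_PHI4_Pattern_2 shown above):

TEST(OrtxTokenizerTest, RegexTestPhi4Pattern2) {
  std::u32string str = U"HELLO World!";
  auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

  std::vector<std::u32string> res;
  std::vector<std::u32string> out_tokens = {U"HELLO", U" World"};

  int64_t max_length = out_tokens.size();
  reg_splitter->Set(str.c_str());
  auto status = reg_splitter->Compile(
      R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
  assert(status.IsOk());

  while (static_cast<int64_t>(res.size()) < max_length) {
    std::u32string_view tok = reg_splitter->GetNextToken();
    res.push_back(ustring(tok));
  }

  EXPECT_EQ(res, out_tokens);
}
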
2 changes: 1 addition & 1 deletion test/test_pp_api.py
@@ -120,7 +120,7 @@ def test_llama3_2_image_processing(self):
a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")

# test sentence for tokenizer
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61"
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61. You'll enjoy the concert."

def test_OLMa_tokenizer(self):
test_sentence = [self.tokenizer_test_sentence + " |||IP_ADDRESS|||"]
