Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 137 additions & 3 deletions operators/tokenizer/bpe_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class SpecialTokenMap {
auto search_it = std::search(it, str.first.end(), std::boyer_moore_searcher(st.str.begin(), st.str.end()));
#endif
if (search_it == str.first.end()) {
new_split_res.emplace_back(std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos),
kInvalidTokenId);
new_split_res.emplace_back(
std::u32string_view(str.first.data() + search_pos, str.first.size() - search_pos), kInvalidTokenId);
break;
}

Expand Down Expand Up @@ -359,6 +359,117 @@ class PreTokenizerWithRegEx {
return {};
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?
// Matches one "word" of the PHI-4 pre-tokenizer pattern: an optional leading
// character that is not CR/LF, a letter, or a number; an optional run of
// uppercase/titlecase/modifier/other letters and marks; a required run of
// lowercase-ish letters (the `+` may also be satisfied by the tail of the
// previous run, since Lm/Lo/M belong to both classes); and an optional
// case-insensitive English contraction suffix.
// On success, removes the matched prefix from m_text and returns it;
// on failure, returns an empty view and leaves m_text untouched.
std::u32string_view Match_PHI4_Pattern_1() {
  if (m_text.empty()) return {};  // guard: operator[] on an empty view is UB

  size_t i = 0;
  // [^\r\n\p{L}\p{N}]?
  if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
    i++;
  }

  size_t j = i;
  // [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*
  const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
                                                        ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
                                                        ufal::unilib::unicode::M;
  // bounds check first: i may already equal m_text.size() after the optional prefix
  while (j < m_text.size() && IsCategory(m_text[j], categories1)) {
    ++j;
  }

  // [\p{Ll}\p{Lm}\p{Lo}\p{M}]+
  const ufal::unilib::unicode::category_t categories2 =
      ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
  if (j < m_text.size() && IsCategory(m_text[j], categories2)) {
    for (; j < m_text.size(); ++j) {
      if (!IsCategory(m_text[j], categories2)) break;
    }
  } else if (j > i && IsCategory(m_text[j - 1], categories2)) {
    // The categories1 run ended with a character that is also in categories2
    // (Lm/Lo/M are in both classes), which satisfies the `+` requirement.
    // No further scan is needed: m_text[j] already failed the categories2 test.
  } else {
    return {};  // the mandatory [\p{Ll}\p{Lm}\p{Lo}\p{M}]+ part is absent
  }

  i = j;
  // (?i:'s|'t|'re|'ve|'m|'ll|'d)?
  if (i < m_text.size() && (m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
    const char32_t c1 = m_text[i + 1];
    if (c1 == U's' || c1 == U't' || c1 == U'm' || c1 == U'd' ||
        c1 == U'S' || c1 == U'T' || c1 == U'M' || c1 == U'D') {
      i += 2;
    } else if ((i + 2) < m_text.size()) {
      const char32_t c2 = m_text[i + 2];
      if (((c1 == U'r' || c1 == U'R') && (c2 == U'e' || c2 == U'E')) ||
          ((c1 == U'v' || c1 == U'V') && (c2 == U'e' || c2 == U'E')) ||
          ((c1 == U'l' || c1 == U'L') && (c2 == U'l' || c2 == U'L'))) {
        i += 3;
      }
    }
  }

  std::u32string_view res = m_text.substr(0, i);
  m_text = m_text.substr(i);  // consume the match
  return res;
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?
// Companion to Match_PHI4_Pattern_1 with the quantifiers swapped: the
// uppercase-ish run is required (`+`) and the lowercase-ish run is optional
// (`*`). An optional non-CR/LF, non-letter, non-number prefix and an optional
// case-insensitive contraction suffix surround the letter runs.
// On success, removes the matched prefix from m_text and returns it;
// on failure, returns an empty view and leaves m_text untouched.
std::u32string_view Match_PHI4_Pattern_2() {
  if (m_text.empty()) return {};  // guard: operator[] on an empty view is UB

  size_t i = 0;
  // [^\r\n\p{L}\p{N}]?
  if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
    i++;
  }

  // [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+
  const ufal::unilib::unicode::category_t categories1 = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
                                                        ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
                                                        ufal::unilib::unicode::M;
  // bounds check first: i may already equal m_text.size() after the optional prefix
  if (i < m_text.size() && IsCategory(m_text[i], categories1)) {
    for (; i < m_text.size(); ++i) {
      if (!IsCategory(m_text[i], categories1)) break;
    }
  } else {
    return {};  // the mandatory [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+ part is absent
  }

  // [\p{Ll}\p{Lm}\p{Lo}\p{M}]*
  const ufal::unilib::unicode::category_t categories2 =
      ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
  while (i < m_text.size() && IsCategory(m_text[i], categories2)) {
    ++i;
  }

  // (?i:'s|'t|'re|'ve|'m|'ll|'d)?
  if (i < m_text.size() && (m_text[i] == U'\'') && ((i + 1) < m_text.size())) {
    const char32_t c1 = m_text[i + 1];
    if (c1 == U's' || c1 == U't' || c1 == U'm' || c1 == U'd' ||
        c1 == U'S' || c1 == U'T' || c1 == U'M' || c1 == U'D') {
      i += 2;
    } else if ((i + 2) < m_text.size()) {
      const char32_t c2 = m_text[i + 2];
      if (((c1 == U'r' || c1 == U'R') && (c2 == U'e' || c2 == U'E')) ||
          ((c1 == U'v' || c1 == U'V') && (c2 == U'e' || c2 == U'E')) ||
          ((c1 == U'l' || c1 == U'L') && (c2 == U'l' || c2 == U'L'))) {
        i += 3;
      }
    }
  }

  std::u32string_view res = m_text.substr(0, i);
  m_text = m_text.substr(i);  // consume the match
  return res;
}

// "(\p{N})"
std::u32string_view Match_General_Pattern_1() {
if (IsN(m_text[0])) {
Expand All @@ -376,6 +487,10 @@ class PreTokenizerWithRegEx {
auto patterns = std::vector<std::tuple<std::string_view, RegexMatchFunc>>{
{R"((?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]))",
&PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
&PreTokenizerWithRegEx::Match_PHI4_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)",
&PreTokenizerWithRegEx::Match_PHI4_Pattern_2},
{R"((?i:'s|'t|'re|'ve|'m|'ll|'d))", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_1},
{R"('s|'t|'re|'ve|'m|'ll|'d)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?\p{L}+)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_2},
Expand All @@ -387,6 +502,7 @@ class PreTokenizerWithRegEx {
{R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4},
{R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1},
{R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2},
{R"(?[^\s\p{L}\p{N}]+[\r\n/]*)", &PreTokenizerWithRegEx::Match_LLAMA3_Pattern_4},
{R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1},
};

Expand Down Expand Up @@ -416,7 +532,7 @@ class PreTokenizerWithRegEx {
} else {
if (pattern_size < regex_compound.size()) {
assert(regex_compound[pattern_size] == '|');
pattern_size++; // let the pattern include the '|'
pattern_size++; // let the pattern include the '|'
}
}
regex_compound = regex_prefix + regex_compound.substr(pos + pattern_size);
Expand Down Expand Up @@ -501,11 +617,29 @@ class PreTokenizerWithRegEx {
public:
static bool IsRN(char32_t ch) { return ch == U'\r' || ch == U'\n'; }

// True when `ch` belongs to any of the Unicode general categories set in the
// `category` bit mask.
static bool IsCategory(char32_t ch, ufal::unilib::unicode::category_t category) {
  return (ufal::unilib::unicode::category(ch) & category) != 0;
}

// True when `ch` is a Unicode letter (any L* general category).
static bool IsL(char32_t ch) {
  return (ufal::unilib::unicode::category(ch) & ufal::unilib::unicode::L) != 0;
}

static bool IsLuLtLmLoM(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return ((category & ufal::unilib::unicode::Lu) != 0 || (category & ufal::unilib::unicode::Lt) != 0 ||
(category & ufal::unilib::unicode::Lm) != 0 || (category & ufal::unilib::unicode::Lo) != 0 ||
(category & ufal::unilib::unicode::M) != 0);
}

static bool IsLlLmLoM(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return ((category & ufal::unilib::unicode::Ll) != 0 || (category & ufal::unilib::unicode::Lm) != 0 ||
(category & ufal::unilib::unicode::Lo) != 0 || (category & ufal::unilib::unicode::M) != 0);
}

static bool IsN(char32_t ch) {
auto category = ufal::unilib::unicode::category(ch);
return (category & ufal::unilib::unicode::N) != 0;
Expand Down
20 changes: 20 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) {
OrtxDispose(&tokenizer);
}

// Verifies that the PHI-4 pre-tokenizer regex splits a simple English
// sentence (including the "'ll" contraction) into the expected tokens.
TEST(OrtxTokenizerTest, RegexTest) {
  std::u32string str = U"You'll enjoy the concert.";
  auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

  std::vector<std::u32string> res;
  std::vector<std::u32string> out_tokens = {U"You'll", U" enjoy", U" the", U" concert"};

  int64_t max_length = static_cast<int64_t>(out_tokens.size());
  reg_splitter->Set(str.c_str());
  auto status = reg_splitter->Compile(
      R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
  // assert() is compiled out in release builds (NDEBUG); use a gtest assertion
  // so a compile failure actually fails the test.
  ASSERT_TRUE(status.IsOk());

  while (static_cast<int64_t>(res.size()) < max_length) {
    std::u32string_view tok = reg_splitter->GetNextToken();
    // An empty token means matching stalled; bail out instead of looping forever.
    ASSERT_FALSE(tok.empty());
    res.push_back(ustring(tok));
  }

  EXPECT_EQ(res, out_tokens);
}

TEST(OrtxTokenizerTest, ClipTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/tokenizer/clip");
Expand Down
2 changes: 1 addition & 1 deletion test/test_pp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_llama3_2_image_processing(self):
a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")

# test sentence for tokenizer
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61"
tokenizer_test_sentence = "I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61. You'll enjoy the concert."

def test_OLMa_tokenizer(self):
test_sentence = [self.tokenizer_test_sentence + " |||IP_ADDRESS|||"]
Expand Down
Loading