Merged
Changes from 1 commit
2 changes: 2 additions & 0 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -417,6 +417,8 @@ class BpeModel {

if (model_name == "Llama") {
return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
} else if (model_name == "Phi") {
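// Phi models use the pre-tokenizer split defined by PHI4_REGEX_PATTERN in bpe_utils.hpp.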
return bpe::PreTokenizerWithRegEx::PHI4_REGEX_PATTERN;
}

// by default, use the GPT2 pretokenizer regex
112 changes: 112 additions & 0 deletions operators/tokenizer/bpe_utils.hpp
@@ -97,6 +97,8 @@ class PreTokenizerWithRegEx {
R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)";
static constexpr const char LLAMA_REGEX_PATTERN[] =
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+)";
static constexpr const char PHI4_REGEX_PATTERN[] =
R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|?[^\s\p{L}\p{N}]+[\r\n]*)";

PreTokenizerWithRegEx() = default;

@@ -359,6 +361,114 @@ class PreTokenizerWithRegEx {
return {};
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?
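// Worked example (illustrative, not from the original change): for m_text = U"You'll enjoy ...",
// the optional non-letter prefix matches nothing, the uppercase run consumes "Y", the lowercase
// run consumes "ou", and the contraction branch consumes "'ll", so "You'll" is returned and
// m_text is advanced past it.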
std::u32string_view Match_PHI4_Pattern_1() {
size_t i = 0;

// [^\r\n\p{L}\p{N}]?
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
i++;
}

// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*
// unicode::category() returns a single leaf category, so test it against a combined mask;
// unicode::M is the composite Mn|Mc|Me and would never compare equal to a leaf value.
const auto upper_mask = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
ufal::unilib::unicode::M;
while (i < m_text.size() && (ufal::unilib::unicode::category(m_text[i]) & upper_mask) != 0) {
i++;
}

// [\p{Ll}\p{Lm}\p{Lo}\p{M}]+
size_t j = i;
const auto lower_mask = ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm |
ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
while (i < m_text.size() && (ufal::unilib::unicode::category(m_text[i]) & lower_mask) != 0) {
i++;
}
if (i == j) {
// The '+' class matched nothing, so this alternative does not apply at this position;
// return an empty view without consuming input so the caller can try another pattern.
return {};
}

// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
if (((i + 1) < m_text.size()) && (m_text[i] == U'\'')) {
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
i += 2;
} else if ((i + 2) < m_text.size()) {
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
i += 3;
}
}
}

std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);
return res;
}

// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?
std::u32string_view Match_PHI4_Pattern_2() {
size_t i = 0;

// [^\r\n\p{L}\p{N}]?
if (!IsRN(m_text[i]) && !IsN(m_text[i]) && !IsL(m_text[i])) {
i++;
}

// [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+
// See Match_PHI4_Pattern_1: test the leaf category against a combined mask.
size_t j = i;
const auto upper_mask = ufal::unilib::unicode::Lu | ufal::unilib::unicode::Lt |
ufal::unilib::unicode::Lm | ufal::unilib::unicode::Lo |
ufal::unilib::unicode::M;
while (i < m_text.size() && (ufal::unilib::unicode::category(m_text[i]) & upper_mask) != 0) {
i++;
}
if (i == j) {
// The '+' class matched nothing, so this alternative does not apply at this position;
// return an empty view without consuming input so the caller can try another pattern.
return {};
}

// [\p{Ll}\p{Lm}\p{Lo}\p{M}]*
const auto lower_mask = ufal::unilib::unicode::Ll | ufal::unilib::unicode::Lm |
ufal::unilib::unicode::Lo | ufal::unilib::unicode::M;
while (i < m_text.size() && (ufal::unilib::unicode::category(m_text[i]) & lower_mask) != 0) {
i++;
}

// (?i:'s|'t|'re|'ve|'m|'ll|'d)?
if (((i + 1) < m_text.size()) && (m_text[i] == U'\'')) {
if ((m_text[i + 1] == U's') || (m_text[i + 1] == U't') || (m_text[i + 1] == U'm') || (m_text[i + 1] == U'd') ||
(m_text[i + 1] == U'S') || (m_text[i + 1] == U'T') || (m_text[i + 1] == U'M') || (m_text[i + 1] == U'D')) {
i += 2;
} else if ((i + 2) < m_text.size()) {
if ((((m_text[i + 1] == U'r') || (m_text[i + 1] == U'R')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'v') || (m_text[i + 1] == U'V')) && ((m_text[i + 2] == U'e') || (m_text[i + 2] == U'E'))) ||
(((m_text[i + 1] == U'l') || (m_text[i + 1] == U'L')) && ((m_text[i + 2] == U'l') || (m_text[i + 2] == U'L')))) {
i += 3;
}
}
}

std::u32string_view res = m_text.substr(0, i);
m_text = m_text.substr(i);
return res;
}

// "(\p{N})"
std::u32string_view Match_General_Pattern_1() {
if (IsN(m_text[0])) {
@@ -387,6 +497,8 @@ class PreTokenizerWithRegEx {
{R"(\s+(?!\S)|\s+)", &PreTokenizerWithRegEx::Match_GPT2_Pattern_4},
{R"([\p{L}]+|[\p{N}])", &PreTokenizerWithRegEx::Match_CLIP_Pattern_1},
{R"([^\s\p{L}\p{N}]+)", &PreTokenizerWithRegEx::Match_CLIP_Pattern_2},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", &PreTokenizerWithRegEx::Match_PHI4_Pattern_1},
{R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)", &PreTokenizerWithRegEx::Match_PHI4_Pattern_2},
{R"(\p{N})", &PreTokenizerWithRegEx::Match_General_Pattern_1},
};

20 changes: 20 additions & 0 deletions test/pp_api_test/test_tokenizer.cc
@@ -65,6 +65,26 @@ TEST(CApiTest, StreamApiTest) {
OrtxDispose(&tokenizer);
}

TEST(OrtxTokenizerTest, RegexTest) {
std::u32string str = U"You'll enjoy the concert.";
auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::u32string> res;
std::vector<std::u32string> out_tokens = {U"You'll"};

int64_t max_length = out_tokens.size();
reg_splitter->Set(str.c_str());
// Compile just the first PHI4 alternative (lowercase-ending word with optional contraction).
auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
ASSERT_TRUE(status.IsOk());

while (static_cast<int64_t>(res.size()) < max_length) {
std::u32string_view tok = reg_splitter->GetNextToken();
res.push_back(ustring(tok));
}
EXPECT_EQ(res, out_tokens);
}
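
// Sketch (illustrative, not part of this change): a companion check for the second PHI4
// alternative, which matches uppercase-leading words such as "HTML". The test name is new here,
// but it reuses the same Set/Compile/GetNextToken sequence exercised by RegexTest above.
TEST(OrtxTokenizerTest, RegexUppercaseTest) {
std::u32string str = U"HTML tags";
auto reg_splitter = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::u32string> res;
std::vector<std::u32string> out_tokens = {U"HTML"};

int64_t max_length = out_tokens.size();
reg_splitter->Set(str.c_str());
auto status = reg_splitter->Compile(R"([^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?)");
ASSERT_TRUE(status.IsOk());

while (static_cast<int64_t>(res.size()) < max_length) {
std::u32string_view tok = reg_splitter->GetNextToken();
res.push_back(ustring(tok));
}
EXPECT_EQ(res, out_tokens);
}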

TEST(OrtxTokenizerTest, ClipTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/tokenizer/clip");