Skip to content

Commit c8bb35d

Browse files
authored
a general regex match algorithm to fix all related issues (#874)
* a general regex match algorithm to fix all related issues * fix all issues and code clean up * revert the test case
1 parent f3f6caa commit c8bb35d

File tree

5 files changed

+275
-803
lines changed

5 files changed

+275
-803
lines changed

operators/tokenizer/bpe_kernels.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
263263
// Parse input
264264
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
265265
bpe::PreTokenizerWithRegEx reg_splitter;
266+
// NOTE: the pattern was already validated when loading the JSON file.
267+
// safe to ignore the return value here.
268+
auto status = reg_splitter.Compile(bbpe_tokenizer_->GetPreTokenizerRegex(ModelName()));
269+
assert(status.IsOk());
266270

267271
for (auto& seg_id : special_token_split_res) {
268272
if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -287,13 +291,12 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
287291
}
288292

289293
while (static_cast<int64_t>(res.size()) < max_length) {
290-
std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
291-
auto [b, tok] = reg_splitter.GetNextToken(regex_expr);
292-
293-
if (!b) break;
294+
std::u32string_view tok = reg_splitter.GetNextToken();
295+
if (tok.empty()) {
296+
break;
297+
}
294298

295299
std::string utf8_token = std::string(ustring(tok));
296-
297300
size_t space_dif = 0;
298301
if (compute_offset_mapping) {
299302
// Handle special case for offset mapping

operators/tokenizer/bpe_tokenizer_model.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ class BpeModel {
7777
ORTX_JSON_RETURN_IF_NULL(&node, "pattern", iter_pattern);
7878
ORTX_JSON_RETURN_IF_NULL(iter_pattern, "Regex", regex_str);
7979
pre_tokenizer_regex_ = regex_str->get<std::string>();
80+
// Validate the regex pattern
81+
bpe::PreTokenizerWithRegEx pre_tokenizer;
82+
auto status = pre_tokenizer.Compile(pre_tokenizer_regex_);
83+
if (!status.IsOk()) {
84+
return status;
85+
}
8086
} else {
8187
if (pre_tokenizer_types_.count(pre_type) == 0) {
8288
return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
@@ -146,7 +152,7 @@ class BpeModel {
146152
} else {
147153
vocab_map_[line] = id;
148154
}
149-
special_tokens_.Add(std::move(line_32), id);
155+
ORTX_RETURN_IF_ERROR(special_tokens_.Add(std::move(line_32), id));
150156
}
151157
}
152158

0 commit comments

Comments
 (0)