Commit 5eda73c

Add regex loading from tokenizer.json and code refinement
1 parent 378bbef commit 5eda73c

File tree

10 files changed: +128, -76 lines

onnxruntime_extensions/pp_api.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
         self.tokenizer = create_tokenizer(tokenizer_dir)
 
     def tokenize(self, text):
+        if isinstance(text, (list, tuple)):
+            return batch_tokenize(self.tokenizer, text)
         return batch_tokenize(self.tokenizer, [text])[0]
 
     def detokenize(self, tokens):
-        return batch_detokenize(self.tokenizer, [tokens])[0]
+        return batch_detokenize(self.tokenizer, [tokens])
 
     def __del__(self):
         if delete_object and self.tokenizer:
```
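
With this change, `Tokenizer.tokenize` accepts either a single string or a batch (list or tuple) of strings, and `detokenize` now returns the batched result instead of unwrapping the first element. A minimal usage sketch, assuming the same model id as the new test below; any tokenizer directory or Hugging Face id accepted by `pp_api.Tokenizer` should work:

```python
from onnxruntime_extensions import pp_api

# Model id mirrors the one used in test_pp_api.py; it is only an example here.
tokenizer = pp_api.Tokenizer("amd/AMD-OLMo-1B-SFT-DPO")

single_ids = tokenizer.tokenize("I like walking my cute dog")       # one sequence of ids
batch_ids = tokenizer.tokenize(["first sentence", "second one"])     # one sequence per input

# detokenize now returns the batch result (a list of strings), not its first element
decoded = tokenizer.detokenize(single_ids)
print(batch_ids, decoded)
```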

operators/tokenizer/bpe_kernels.cc

Lines changed: 4 additions & 10 deletions
```diff
@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
   // Parse input
   auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
-  bpe::TokenWithRegularExp regcmp;
+  bpe::PreTokenizerWithRegEx reg_splitter;
 
   for (auto& seg_id : special_token_split_res) {
     if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
     // Note: keep ptr to make sure the string_view is valid in the following process
     std::u32string str(seg_id.first);
-    regcmp.Set(str.c_str());
+    reg_splitter.Set(str.c_str());
 
     size_t offset = 0;
     OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
     }
 
     while (static_cast<int64_t>(res.size()) < max_length) {
-      std::string regex_expr = "";
-      if (ModelName() == kModel_Llama){
-        regex_expr = regcmp.LLAMA_REGEX_PATTERN;
-      } else {
-        // default to GPT2 regex
-        regex_expr = regcmp.GPT2_REGEX_PATTERN;
-      }
-      auto [b, tok] = regcmp.GetNextToken(regex_expr);
+      std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
+      auto [b, tok] = reg_splitter.GetNextToken(regex_expr);
 
       if (!b) break;
```

operators/tokenizer/bpe_tokenizer_model.hpp

Lines changed: 68 additions & 0 deletions
```diff
@@ -44,6 +44,58 @@ class BpeModel {
     }
   }
 
+  OrtxStatus LoadPreTokenizer(const json& bpe_model) {
+    auto node_pre_tokenizer = bpe_model.find("pre_tokenizer");
+    if (node_pre_tokenizer == bpe_model.end() || node_pre_tokenizer->is_null()) {
+      return {};
+    }
+
+    auto iter_type = node_pre_tokenizer->find("type");
+    if (iter_type == node_pre_tokenizer->end() || iter_type->is_null()) {
+      return {};
+    }
+
+    if (iter_type->get<std::string>() != "Sequence") {
+      return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+    }
+
+    auto iter_node_list = node_pre_tokenizer->find("pretokenizers");
+    if (iter_node_list == node_pre_tokenizer->end() || iter_node_list->is_null()) {
+      return {};
+    }
+
+    for (const auto& node : *iter_node_list) {
+      auto iter_type = node.find("type");
+      if (iter_type == node.end() || iter_type->is_null()) {
+        continue;  // ignore unknown pre-tokenizer type
+      }
+
+      auto pre_type = iter_type->get<std::string>();
+      if (pre_type == "Split") {
+        auto iter_pattern = node.find("pattern");
+        if (iter_pattern == node.end() || iter_pattern->is_null()) {
+          continue;
+        }
+
+        auto regex_str = iter_pattern->find("Regex");
+        if (regex_str == iter_pattern->end() || regex_str->is_null()) {
+          continue;
+        }
+
+        pre_tokenizer_regex_ = regex_str->get<std::string>();
+      } else if (pre_type == "ByteLevel") {
+        ;  // need to add more flag support here in the future
+      } else {
+        return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+      }
+    }
+
+    return {};
+  }
+
   OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
                   const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
@@ -121,6 +173,8 @@ class BpeModel {
   }
 
   OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
+    ORTX_RETURN_IF_ERROR(LoadPreTokenizer(bpe_model));
+
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -358,6 +412,19 @@ class BpeModel {
 
   const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
 
+  std::string GetPreTokenizerRegex(const std::string& model_name) const {
+    if (!pre_tokenizer_regex_.empty()) {
+      return pre_tokenizer_regex_;
+    }
+
+    if (model_name == "Llama") {
+      return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
+    }
+
+    // by default, use the GPT2 pretokenizer regex
+    return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
+  }
+
  private:
  struct BpeNode {
    uint32_t id;
@@ -379,6 +446,7 @@ class BpeModel {
   uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
   bpe::SpecialTokenMap special_tokens_;
   TrieTree<char32_t> added_tokens_;
+  std::string pre_tokenizer_regex_;
 };
 
 } // namespace ort_extensions
```
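
For context, `LoadPreTokenizer` walks the `pre_tokenizer` node of a Hugging Face `tokenizer.json` in its `Sequence` form: it only reads each entry's `type` and, for `Split` entries, `pattern.Regex`; other fields are ignored. A trimmed illustrative snippet of that layout (field values other than `type` and `pattern.Regex` are typical Hugging Face examples, not taken from this commit):

```json
{
  "pre_tokenizer": {
    "type": "Sequence",
    "pretokenizers": [
      {
        "type": "Split",
        "pattern": {
          "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        },
        "behavior": "Isolated",
        "invert": false
      },
      {
        "type": "ByteLevel",
        "add_prefix_space": false,
        "trim_offsets": true,
        "use_regex": false
      }
    ]
  }
}
```

When such a `Regex` is present, `GetPreTokenizerRegex` returns it; otherwise it falls back to `LLAMA_REGEX_PATTERN` for Llama models and `GPT2_REGEX_PATTERN` for everything else, matching the switch that was removed from `bpe_kernels.cc`.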

operators/tokenizer/bpe_utils.hpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -97,8 +97,12 @@ class SpecialTokenMap {
   std::unordered_map<ustring, int> token_map_;
 };
 
-class TokenWithRegularExp {
+class PreTokenizerWithRegEx {
  public:
+  static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
   void Set(std::u32string_view val) {
     m_text = val;
   }
@@ -115,10 +119,6 @@ class TokenWithRegularExp {
     return {false, {}};
   }
 
-  const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-
  public:
 
   // Although we have RegexMatchGeneral which performs regex matching given any general regex string
```
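
These patterns are applied by the class's own matching helpers (`GetNextToken`, `RegexMatchGeneral`), not by `std::regex`. As a rough illustration of what the GPT-2 pattern carves out of a sentence, here is a sketch using Python's third-party `regex` package (an assumption for illustration only; unlike the built-in `re`, it supports `\p{...}` classes):

```python
import regex  # third-party "regex" package; handles \p{L}, \p{N} and the (?!\S) lookahead

GPT2_REGEX_PATTERN = (
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)

# Non-overlapping matches, left to right, reproduce the GPT-2 style pre-tokenization split.
pieces = regex.findall(GPT2_REGEX_PATTERN, "I'm walking 2 cute dogs!")
print(pieces)  # ["I", "'m", " walking", " 2", " cute", " dogs", "!"]
```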

operators/tokenizer/tokenizer_jsconfig.hpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
     {"GPT2Tokenizer", TokenType::kBPE},
     {"Qwen2Tokenizer", TokenType::kBPE},
     {"BaichuanTokenizer", TokenType::kBPE},
+    {"GPTNeoXTokenizer", TokenType::kBPE},
 
     {"", TokenType::kUnigram},
     {"T5Tokenizer", TokenType::kUnigram},
```

pyop/py_c_api.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
         OrtxTokenizer* tokenizer = nullptr;
         auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
         if (err != kOrtxOK) {
-          throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
+          throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
         }
         return reinterpret_cast<std::uintptr_t>(tokenizer);
       },
```

shared/api/tokenizer_impl.cc

Lines changed: 28 additions & 49 deletions
```diff
@@ -11,33 +11,15 @@
 
 namespace ort_extensions {
 
-std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
-    "PreTrainedTokenizerFast",
-    "CLIPTokenizer",
-    "WhisperTokenizer",
-    "GemmaTokenizer",
-    "LlamaTokenizer",
-    "Phi3Tokenizer",
-    "CodeLlamaTokenizer",
-    "CodeGenTokenizer",
-    "GPT2Tokenizer",
-    "Qwen2Tokenizer",
-    "BaichuanTokenizer"
-};
-
-std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
-    "XLMRobertaTokenizer",
-    "T5Tokenizer",
-    "ChatGLMTokenizer"
-};
 
 TokenizerImpl::TokenizerImpl()
     : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
 TokenizerImpl::~TokenizerImpl() {};
 
 OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
-  if (tok_config_->tokenizer_class_.empty() ||
-      supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
+
+  auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
+  if (type == TokenType::kUnigram) {
     auto tokenizer = std::make_unique<SpmUgmTokenizer>();
     auto status = tokenizer->Load(*tok_config_);
     if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
       tokenizer_ = std::move(tokenizer);
       detokenizer_ = std::move(detok);
     }
-
     return status;
-  }
-
-  if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
-    return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
-  }
-
-  auto tokenizer = std::make_unique<JsonFastTokenizer>();
-  auto fx_load = &JsonFastTokenizer::Load;
-  if (blob == nullptr) {
-    auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
-    // vocab file is checked in TokenJsonConfig::Load
-    if (vocab_file_path.extension() != ".json") {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+  } else if (type == TokenType::kBPE) {
+    auto tokenizer = std::make_unique<JsonFastTokenizer>();
+    auto fx_load = &JsonFastTokenizer::Load;
+    if (blob == nullptr) {
+      auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
+      // vocab file is checked in TokenJsonConfig::Load
+      if (vocab_file_path.extension() != ".json") {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    } else {
+      if (blob->raw_model_blob_len > 0) {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
     }
-  } else {
-    if (blob->raw_model_blob_len > 0) {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+
+    auto status = (tokenizer.get()->*fx_load)(*tok_config_);
+    if (!status.IsOk()) {
+      return status;
     }
-  }
 
-  auto status = (tokenizer.get()->*fx_load)(*tok_config_);
-  if (!status.IsOk()) {
-    return status;
-  }
+    auto detok = std::make_unique<BpeStreamingDecoder>();
+    status = detok->Load(tok_config_, *tokenizer);
 
-  auto detok = std::make_unique<BpeStreamingDecoder>();
-  status = detok->Load(tok_config_, *tokenizer);
+    if (status.IsOk()) {
+      tokenizer_ = std::move(tokenizer);
+      detokenizer_ = std::move(detok);
+    }
 
-  if (status.IsOk()) {
-    tokenizer_ = std::move(tokenizer);
-    detokenizer_ = std::move(detok);
+    return status;
   }
 
-  return status;
+  return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
 }
 
 OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
```

shared/api/tokenizer_impl.h

Lines changed: 0 additions & 3 deletions
```diff
@@ -15,9 +15,6 @@ namespace ort_extensions {
 
 class TokenizerImpl : public OrtxObjectImpl {
  public:
-  static std::set<std::string> supported_bpe_models_;
-  static std::set<std::string> supported_ugm_models_;
-
   TokenizerImpl();
   virtual ~TokenizerImpl();
```

test/pp_api_test/test_tokenizer.cc

Lines changed: 4 additions & 4 deletions
```diff
@@ -67,7 +67,7 @@ TEST(CApiTest, StreamApiTest) {
 
 TEST(OrtxTokenizerTest, RegexTest) {
   std::u32string str = U"CAN'T \r\n 2413m";
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::u32string> res;
   std::vector<std::u32string> out_tokens = {U"CAN", U"'T", U" \r\n", U" ", U"241", U"3", U"m"};
@@ -91,7 +91,7 @@ TEST(OrtxTokenizerTest, RegexMatchSTDTest) {
   std::vector<std::u32string> input_strings = {U"not its, or IT'S, but it's",
                                                U" ",
                                                U"AbCd"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"'s"},
@@ -118,7 +118,7 @@ TEST(OrtxTokenizerTest, WrapStandaloneCategoriesTest) {
       "\\p{rn}\\p{L}\\p{N}\\p{L}",
       "\\p{Z}*[\\p{rn}]+",
       "\\p{Z}+"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::string> res;
   std::vector<std::string> out_regex = {"[^\\p{rn}\\p{L}\\p{N}]?[\\p{L}]+",
@@ -152,7 +152,7 @@ TEST(OrtxTokenizerTest, RegexMatchGeneralTest) {
       U"241356m",
       U"Ich liebe München <3 \r\n ",
       U"生活的真谛是"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"CAN", U"'T", U"", U""},
```

test/test_pp_api.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -8,10 +8,12 @@
 
 
 is_pp_api_available = False
+hf_token_id = None
 try:
-    from transformers import AutoImageProcessor
+    from transformers import AutoImageProcessor, AutoTokenizer
     from onnxruntime_extensions import pp_api
     is_pp_api_available = True
+    hf_token_id = os.environ.get("HF_TOKEN", None)
 except ImportError:
     pass
 
@@ -46,7 +48,6 @@ def setUpClass(cls):
         else:
             cls.temp_dir = tempfile.mkdtemp()
             print(f"Created temp dir: {cls.temp_dir}")
-        cls.token_id = os.environ.get("HF_TOKEN", None)
 
     def test_CLIP_image_processing(self):
         model_id = "openai/clip-vit-large-patch14"
@@ -76,6 +77,7 @@ def test_CLIP_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/CLIP_a_{i}.png")
 
+    @unittest.skipIf(hf_token_id is None, "HF_TOKEN is not available")
     def test_llama3_2_image_processing(self):
         model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
@@ -91,7 +93,7 @@ def test_llama3_2_image_processing(self):
                       "test/data/processor/exceltable.png"]
         (image, image2, image3) = [Image.open(f) for f in image_list]
 
-        processor = AutoImageProcessor.from_pretrained(model_id, token=TestPPAPI.token_id)
+        processor = AutoImageProcessor.from_pretrained(model_id, token=hf_token_id)
         inputs = processor.preprocess(
             [image, image2, image3], return_tensors="np")
         print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()})
@@ -114,6 +116,15 @@ def test_llama3_2_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")
 
+    def test_OLMa_tokenizer(self):
+        test_sentence = ["I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61" + " |||IP_ADDRESS|||"]
+        model_id = "amd/AMD-OLMo-1B-SFT-DPO"
+        hf_enc = AutoTokenizer.from_pretrained(model_id)
+        inputs = hf_enc(test_sentence)["input_ids"]
+        tokenizer = pp_api.Tokenizer(model_id)
+        ortx_inputs = tokenizer.tokenize(test_sentence)
+        # self.assertEqual(inputs, ortx_inputs)
+        np.testing.assert_array_equal(ortx_inputs, inputs)
 
 if __name__ == '__main__':
     unittest.main()
```
