@@ -44,6 +44,58 @@ class BpeModel {
     }
   }
 
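+  // Parse the optional "pre_tokenizer" node of a HuggingFace tokenizer.json.
+  // Only a "Sequence" wrapper is supported; a "Split" entry supplies a custom
+  // regex, "ByteLevel" is accepted as a no-op, and anything else is rejected.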
+  OrtxStatus LoadPreTokenizer(const json& bpe_model) {
+    auto node_pre_tokenizer = bpe_model.find("pre_tokenizer");
+    if (node_pre_tokenizer == bpe_model.end() || node_pre_tokenizer->is_null()) {
+      return {};
+    }
+
+    auto iter_type = node_pre_tokenizer->find("type");
+    if (iter_type == node_pre_tokenizer->end() || iter_type->is_null()) {
+      return {};
+    }
+
+    if (iter_type->get<std::string>() != "Sequence") {
+      return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+    }
+
+    auto iter_node_list = node_pre_tokenizer->find("pretokenizers");
+    if (iter_node_list == node_pre_tokenizer->end() || iter_node_list->is_null()) {
+      return {};
+    }
+
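+    // Walk the Sequence entries; only "Split" (with a "Regex" pattern) and "ByteLevel" are recognized.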
+    for (const auto& node : *iter_node_list) {
+      auto iter_type = node.find("type");
+      if (iter_type == node.end() || iter_type->is_null()) {
+        continue;  // skip entries that do not specify a type
+      }
+
+      auto pre_type = iter_type->get<std::string>();
+      if (pre_type == "Split") {
+        auto iter_pattern = node.find("pattern");
+        if (iter_pattern == node.end() || iter_pattern->is_null()) {
+          continue;
+        }
+
+        auto regex_str = iter_pattern->find("Regex");
+        if (regex_str == iter_pattern->end() || regex_str->is_null()) {
+          continue;
+        }
+
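+        // A "Split" pattern's regex overrides the built-in defaults returned by GetPreTokenizerRegex().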
+        pre_tokenizer_regex_ = regex_str->get<std::string>();
+      } else if (pre_type == "ByteLevel") {
+        // nothing to do yet; more flag support may be added here in the future
+      } else {
+        return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+      }
+    }
+
+    return {};
+  }
+
   OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
                   const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
@@ -121,6 +173,8 @@ class BpeModel {
   }
 
   OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
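+    // Pick up any custom pre-tokenizer regex before the vocab and merges are loaded.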
+    ORTX_RETURN_IF_ERROR(LoadPreTokenizer(bpe_model));
+
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -358,6 +412,19 @@ class BpeModel {
 
   const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
 
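+  // Returns the regex captured from tokenizer.json when present, otherwise a model-specific default.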
+  std::string GetPreTokenizerRegex(const std::string& model_name) const {
+    if (!pre_tokenizer_regex_.empty()) {
+      return pre_tokenizer_regex_;
+    }
+
+    if (model_name == "Llama") {
+      return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
+    }
+
+    // by default, use the GPT-2 pre-tokenizer regex
+    return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
+  }
+
  private:
   struct BpeNode {
     uint32_t id;
@@ -379,6 +446,7 @@ class BpeModel {
   uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
   bpe::SpecialTokenMap special_tokens_;
   TrieTree<char32_t> added_tokens_;
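+  // regex captured from tokenizer.json's Split pre-tokenizer; empty when none was specified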
+  std::string pre_tokenizer_regex_;
 };
 
 }  // namespace ort_extensions
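For reference, a minimal, self-contained sketch of the tokenizer.json shape this parser expects. It redoes the lookup with plain nlohmann::json rather than calling BpeModel directly; the JSON fragment, the regex, and the main() harness are illustrative only, not part of this change:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  using json = nlohmann::json;
  // Hypothetical tokenizer.json fragment: a Sequence pre-tokenizer holding
  // a Split node with a Regex pattern, followed by a ByteLevel node.
  json tok = json::parse(R"({
    "pre_tokenizer": {
      "type": "Sequence",
      "pretokenizers": [
        { "type": "Split", "pattern": { "Regex": "\\s?\\p{L}+" } },
        { "type": "ByteLevel" }
      ]
    }
  })");

  // Mirror the lookup LoadPreTokenizer() performs: find a Split node inside
  // the Sequence and pull out its pattern.Regex string.
  std::string regex;
  auto pre = tok.find("pre_tokenizer");
  if (pre != tok.end() && (*pre)["type"] == "Sequence") {
    for (const auto& node : (*pre)["pretokenizers"]) {
      if (node.value("type", "") == "Split") {
        regex = node["pattern"]["Regex"].get<std::string>();  // what ends up in pre_tokenizer_regex_
      }
    }
  }
  std::cout << "custom pre-tokenizer regex: " << regex << "\n";  // prints \s?\p{L}+
}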