Skip to content

Commit f8f3ae9

Browse files
authored
Support fast tokenizer type in JSON tokenizer (#876)
- Support fast tokenizer type too
- More logging
- Update tokenizer dictionary entry name
1 parent e8bf5a9 commit f8f3ae9

File tree

5 files changed

+21
-6
lines changed

5 files changed

+21
-6
lines changed

.pipelines/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ stages:
740740
steps:
741741
- script: |
742742
cd $(Build.BinariesDirectory)
743-
git clone https://github.com/emscripten-core/emsdk
743+
git clone https://github.com/emscripten-core/emsdk --depth 1 --branch 3.1.74
744744
emsdk/emsdk install latest
745745
emsdk/emsdk activate latest
746746
displayName: Setup emscripten pipeline

operators/tokenizer/tokenizer_jsconfig.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ enum class TokenType {
1515
};
1616

1717
constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
18-
{"PreTrainedTokenizerFast", TokenType::kBPE},
18+
{"PreTrainedTokenizer", TokenType::kBPE},
1919
{"CLIPTokenizer", TokenType::kBPE},
2020
{"WhisperTokenizer", TokenType::kBPE},
2121
{"GemmaTokenizer", TokenType::kBPE},
@@ -256,10 +256,16 @@ class TokenJsonConfig final {
256256
}
257257

258258
static TokenType GetTokenType(const std::string& tok) {
259-
static const std::unordered_map<std::string, TokenType> dict {
259+
static const std::unordered_map<std::string_view, TokenType> dict {
260260
std::begin(kTokenizerDict), std::end(kTokenizerDict) };
261261

262-
auto iter = dict.find(tok);
262+
std::string_view tok_class(tok);
263+
auto pos = tok_class.find("Fast");
264+
if (pos != std::string_view::npos && pos + 4 == tok_class.size()) {
265+
tok_class.remove_suffix(4);
266+
}
267+
268+
auto iter = dict.find(tok_class);
263269
return iter == dict.end() ? TokenType::kUnknown : iter->second;
264270
}
265271

operators/tokenizer/tokenizer_op_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class JsonTokenizerOpKernel {
3333
} else if (type == TokenType::kBPE) {
3434
tokenizer_ = std::make_unique<JsonFastTokenizer>();
3535
} else {
36-
return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type");
36+
return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type: " + cfg.tokenizer_class_);
3737
}
3838

3939
return std::visit([&](auto& ptr) { return ptr->Load(cfg); }, tokenizer_);

shared/api/tokenizer_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
6767
return status;
6868
}
6969

70-
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
70+
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class: " + tok_config_->tokenizer_class_);
7171
}
7272

7373
OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {

test/test_pp_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ def test_Qwen_QVQ_tokenizer(self):
140140
ortx_inputs = tokenizer.tokenize(test_sentence)
141141
np.testing.assert_array_equal(ortx_inputs, inputs)
142142

143+
def test_Phi4_tokenizer(self):
144+
model_id = "/g/phi-x-12202024"
145+
test_sentence = [self.tokenizer_test_sentence]
146+
hf_enc = AutoTokenizer.from_pretrained(model_id)
147+
inputs = hf_enc(test_sentence)["input_ids"]
148+
tokenizer = pp_api.Tokenizer(model_id)
149+
ortx_inputs = tokenizer.tokenize(test_sentence)
150+
np.testing.assert_array_equal(ortx_inputs, inputs)
151+
143152

144153
if __name__ == "__main__":
145154
unittest.main()

0 commit comments

Comments (0)