
Commit 378bbef

Authored by sayanshaw24 (Sayan Shaw) and co-author
Add Python API HF Embedded JSON tokenizer support (#860)
* add Python API HF embedded JSON tokenizer support
* remove xlmrobertatokenizer test as it is not on HF

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
1 parent 1a21d45 commit 378bbef

File tree

3 files changed: +289 −26 lines

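For orientation, a minimal usage sketch of the new schema_v2 path (not part of the diff; the repo id and sample text are illustrative, and it assumes an onnxruntime-extensions build that provides the new HfJsonTokenizer kernel):

    from onnxruntime_extensions import OrtPyFunction, gen_processing_models

    # processor may be an HF repo id or a tokenizer object; schema_v2=True embeds
    # tokenizer.json / tokenizer_config.json as string attributes of a single
    # HfJsonTokenizer node instead of converting the tokenizer op-by-op.
    pre_model, _ = gen_processing_models(
        "Xenova/gpt-4",                            # illustrative repo id
        pre_kwargs={"WITH_DEFAULT_INPUTS": True},
        schema_v2=True)

    ort_tok = OrtPyFunction.from_model(pre_model)
    ids = ort_tok(["Testing the embedded JSON tokenizer path"])[0]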

onnxruntime_extensions/_cuops.py

Lines changed: 11 additions & 1 deletion
@@ -491,6 +491,16 @@ def get_outputs(cls):
         ]


+class HfJsonTokenizer(CustomOp):
+    @classmethod
+    def get_inputs(cls):
+        return [cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])]
+
+    @classmethod
+    def get_outputs(cls):
+        return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]
+
+
 # TODO: have a C++ impl.
 def _argsort_op(x, dim):
     d = numpy.argsort(x, dim)
@@ -544,4 +554,4 @@ def build_graph(cls, op_class, *args, **kwargs):

     @staticmethod
     def get_op_class(op_type):
-        return globals()[op_type]
+        return globals()[op_type]
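The new CustomOp subclass above only declares the operator's I/O schema: a 1-D string tensor in, an int64 id matrix out. A hedged sketch of how it gets consumed through SingleOpGraph, mirroring the call added in cvt.py below (the file paths are illustrative):

    from onnxruntime_extensions._ortapi2 import SingleOpGraph, make_onnx_model

    # Illustrative paths; in cvt.py these strings come from the resolved HF cache directory.
    with open("tokenizer.json", encoding="utf-8") as f:
        tokenizer_vocab = f.read()
    with open("tokenizer_config.json", encoding="utf-8") as f:
        tokenizer_config = f.read()

    # get_op_class resolves the op name via globals()["HfJsonTokenizer"], and the JSON
    # strings are attached to the node as attributes.
    graph = SingleOpGraph.build_graph(
        "HfJsonTokenizer",
        tokenizer_vocab=tokenizer_vocab,
        tokenizer_config=tokenizer_config)
    model = make_onnx_model(graph)  # str ['N'] input -> int64 ['N', None] ids output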

onnxruntime_extensions/cvt.py

Lines changed: 235 additions & 25 deletions
@@ -12,6 +12,24 @@
 from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph  # noqa
 from ._ortapi2 import make_onnx_model, SingleOpGraph

+import os
+import numpy as np
+import tempfile
+import shutil
+
+# edit environment variables to avoid protobuf version mismatch
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+from transformers.convert_slow_tokenizer import SpmConverter  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+from tokenizers import decoders, normalizers, pre_tokenizers, Regex  # noqa: E402
+
+
+OrtxTokenizer = None
+try:
+    from onnxruntime_extensions.pp_api import Tokenizer as OrtxTokenizer
+except ImportError:
+    pass

 _is_torch_available = False
 try:
@@ -24,11 +42,150 @@

 _PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}

+def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
+    if add_prefix_space:
+        prepend_scheme = "always"
+        if not getattr(original_tokenizer, "legacy", True):
+            prepend_scheme = "first"
+    else:
+        prepend_scheme = "never"
+    return prepend_scheme
+
+
+class Baichuan2Converter(SpmConverter):
+    handle_byte_fallback = True
+
+    def __init__(self, original_tokenizer):
+        super().__init__(original_tokenizer)
+        original_tokenizer.add_prefix_space = False
+
+    def vocab(self, proto):
+        vocab = [
+            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
+        ]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        return vocab
+
+    def unk_id(self, proto):
+        unk_id = 0
+        return unk_id
+
+    def decoder(self, replacement, add_prefix_space):
+        sequence = [
+            decoders.Replace("▁", " "),
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+        ]
+        if add_prefix_space:
+            sequence += [decoders.Strip(content=" ", left=1)]
+        return decoders.Sequence(sequence)
+
+    def normalizer(self, proto):
+        if getattr(self.original_tokenizer, "legacy", True):
+            sequence = []
+            if getattr(self.original_tokenizer, "add_prefix_space", True):
+                sequence += [normalizers.Prepend(prepend="▁")]
+            sequence += [normalizers.Replace(pattern=" ", content="▁")]
+            return normalizers.Sequence(sequence)
+        return None  # non-legacy, no normalizer
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
+            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
+        else:
+            return super().pre_tokenizer(replacement, add_prefix_space)
+
+
+class ChatGlmConverter(SpmConverter):
+    def normalizer(self, proto):
+        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+        _normalizers = [
+            normalizers.Strip(left=False, right=True),  # stripping is important
+            normalizers.Replace(Regex(" {2,}"), "▁"),
+        ]
+        return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        prepend_scheme = "always"
+        if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
+            prepend_scheme = "first"
+        return pre_tokenizers.Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
+        )
+
+
+JSON_TOKEN_CONVERTERS = {
+    "BaichuanTokenizer": Baichuan2Converter,
+    "ChatGLMTokenizer": ChatGlmConverter,
+}
+
+# Save tokenizer JSON files using HuggingFace AutoTokenizer
+def convert_tokenizer(model_path, output_dir):
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if output_dir is None:
+        if os.path.isdir(model_path):
+            output_dir = model_path
+        else:
+            # create a temporary directory
+            output_dir = tempfile.mkdtemp()
+            tokenizer.save_pretrained(output_dir)
+    json_path = os.path.join(output_dir, "tokenizer.json")
+
+    if type(tokenizer).__name__ in JSON_TOKEN_CONVERTERS:
+        GenericSpmConverter = JSON_TOKEN_CONVERTERS[type(tokenizer).__name__]
+
+    converted = GenericSpmConverter(tokenizer).converted()
+    converted.save(json_path)
+    print(f"**Tokenizer saved to {json_path}")
+    return output_dir
+
+# Validate tokenizer files downloaded from memory
+def validate_tokenizer(model_path, output_dir):
+    test_sentence = "I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"
+    if OrtxTokenizer is None:
+        print("onnxruntime_extensions package was built with C API enabled, skipping tokenization test")
+    ortx_tokenizer = OrtxTokenizer(output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+    expected_ids = tokenizer(test_sentence, return_tensors="np")["input_ids"]
+    ortx_ids = np.asarray(ortx_tokenizer.tokenize(test_sentence))
+    assert np.array_equal(expected_ids[0], ortx_ids), f"Tokenization mismatch: {expected_ids[0]} != {ortx_ids}"
+    print("Tokenization test passed")
+
+# Download tokenizer JSON files from memory
+def download_tokenizer(tokenizer_dir, output_dir):
+    try:
+        from transformers.utils import cached_file
+
+        resolved_full_file = cached_file(tokenizer_dir, "tokenizer.json")
+        resolved_config_file = cached_file(tokenizer_dir, "tokenizer_config.json")
+    except ImportError:
+        raise ValueError(f"Directory '{tokenizer_dir}' not found and transformers is not available")
+    if not os.path.exists(resolved_full_file):
+        raise FileNotFoundError(f"Downloaded HF file '{resolved_full_file}' cannot be found")
+    if os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file):
+        raise FileNotFoundError(
+            f"Downloaded HF files '{resolved_full_file}' " f"and '{resolved_config_file}' are not in the same directory"
+        )

+    if output_dir is None or len(output_dir) == 0:
+        output_dir = os.path.dirname(resolved_full_file)
+        print(f"Using {output_dir} as output directory")
+        return output_dir
+    else:
+        # copy the files to the output directory
+        shutil.copy(resolved_full_file, output_dir)
+        shutil.copy(resolved_config_file, output_dir)
+        return output_dir
+

 def gen_processing_models(processor: Union[str, object],
                           pre_kwargs: dict = None,
                           post_kwargs: dict = None,
                           opset: int = None,
+                          schema_v2: bool = False,
                           **kwargs):
     """
     Generate the pre- and post-processing ONNX model, basing on the name or HF class.
@@ -47,6 +204,9 @@ def gen_processing_models(processor: Union[str, object],
         Keyword arguments for generating the post-processing model
     opset: int
         the target opset version of the model
+    schema_v2: bool
+        the flag for using embedded tokenizer files; this option leverages the blob-loading functionality
+        which loads HF tokenizers from memory rather than using the tokenizer files in HF JSON format.
     kwargs:
         The additional arguments for generating models

@@ -58,39 +218,89 @@
     if pre_kwargs is None and post_kwargs is None:
         raise ValueError(
             "Either pre_kwargs or post_kwargs should be provided. None means no processing graph output.")
-    if isinstance(processor, str):
+
+    # If true, we get the tokenizer JSON files by either downloading from cache or using HuggingFace AutoTokenizer
+    # to convert them, and then create an ONNX model with the JSON files as strings in the model attributes (attrs).
+    if schema_v2:
+        model_name = processor if isinstance(processor, str) else type(processor).__name__
+
+        converted_tokenizer = {"Baichuan2", "chatglm"}
+        need_convert = False
+        for token in converted_tokenizer:
+            if model_name.find(token) != -1:
+                need_convert = True
+                break
+
+        if need_convert:
+            model_dir = convert_tokenizer(model_name)
+            validate_tokenizer(model_name, None)
+        else:
+            model_dir = download_tokenizer(model_name, None)
+
+        # Load the content of tokenizer.json into a string
+        with open(f"{model_dir}/tokenizer.json", "r", encoding="utf-8") as f:
+            tokenizer_vocab = f.read()
+
+        # Load the content of tokenizer_config.json into a string
+        with open(f"{model_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
+            tokenizer_config = f.read()
+
+        # Create an ONNX model with these JSON file strings in attrs
         g_pre, g_post = (None, None)
-        if pre_kwargs:
-            g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
-        if post_kwargs:
+        if pre_kwargs is not None:
+            # Add tokenizer_vocab and tokenizer_config to the kwargs
+            # so they are added to attrs in build_graph
+            pre_kwargs['tokenizer_vocab'] = tokenizer_vocab
+            pre_kwargs['tokenizer_config'] = tokenizer_config
+            g_pre = SingleOpGraph.build_graph("HfJsonTokenizer", **pre_kwargs)
+        if post_kwargs is not None:
             if pre_kwargs is None:
                 cls_name = processor
             else:
                 if processor not in _PRE_POST_PAIR:
                     raise RuntimeError(
                         f"Cannot locate the post processing operator name from {processor}")
                 cls_name = _PRE_POST_PAIR[processor]
+            # Add tokenizer_vocab and tokenizer_config to the kwargs
+            # so they are added to attrs in build_graph
+            post_kwargs['tokenizer_vocab'] = tokenizer_vocab
+            post_kwargs['tokenizer_config'] = tokenizer_config
             g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
         return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None
-
-    cls_name = type(processor).__name__
-    if cls_name == "WhisperProcessor":
-        if WhisperDataProcGraph is None:
-            raise ValueError(
-                "The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
-        _converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
-        pre_m = _converter.pre_processing(
-            **pre_kwargs) if pre_kwargs is not None else None
-        post_m = _converter.post_processing(
-            **post_kwargs) if post_kwargs is not None else None
-        return pre_m, post_m
-    elif HFTokenizerOnnxGraph.is_supported(processor):
-        _converter = HFTokenizerOnnxGraph(processor)
-        pre_g = _converter.pre_processing(
-            **pre_kwargs) if pre_kwargs is not None else None
-        post_g = _converter.post_processing(
-            **post_kwargs) if post_kwargs is not None else None
-        return make_onnx_model(pre_g) if pre_g else None, \
-            make_onnx_model(post_g) if post_g else None
     else:
-        raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")
+        if isinstance(processor, str):
+            g_pre, g_post = (None, None)
+            if pre_kwargs:
+                g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
+            if post_kwargs:
+                if pre_kwargs is None:
+                    cls_name = processor
+                else:
+                    if processor not in _PRE_POST_PAIR:
+                        raise RuntimeError(
+                            f"Cannot locate the post processing operator name from {processor}")
+                    cls_name = _PRE_POST_PAIR[processor]
+                g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
+            return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None
+
+        cls_name = type(processor).__name__
+        if cls_name == "WhisperProcessor":
+            if WhisperDataProcGraph is None:
+                raise ValueError(
+                    "The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
+            _converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
+            pre_m = _converter.pre_processing(
+                **pre_kwargs) if pre_kwargs is not None else None
+            post_m = _converter.post_processing(
+                **post_kwargs) if post_kwargs is not None else None
+            return pre_m, post_m
+        elif HFTokenizerOnnxGraph.is_supported(processor):
+            _converter = HFTokenizerOnnxGraph(processor)
+            pre_g = _converter.pre_processing(
+                **pre_kwargs) if pre_kwargs is not None else None
+            post_g = _converter.post_processing(
+                **post_kwargs) if post_kwargs is not None else None
+            return make_onnx_model(pre_g) if pre_g else None, \
+                make_onnx_model(post_g) if post_g else None
+        else:
+            raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")

test/test_embedded_tokenizer.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import unittest
+
+import numpy as np
+from transformers import AutoTokenizer, GPT2Tokenizer
+from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference
+
+
+class TestEmbeddedTokenizer(unittest.TestCase):
+    def test_clip_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "openai/clip-vit-base-patch32", use_fast=False)
+        text = """
+        1. Testing long text with multiple lines to check newline handling
+        2. As well as words with apostrophes such as you're, i'm, don't, etc.
+        3. And weird characters such as . , ~ ? ( ) " [ ] ! : - .
+        """
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0],
+            schema_v2=True)
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(ids, actual_ids)
+
+    def test_gpt2_tokenizer(self):
+        tokenizer = GPT2Tokenizer.from_pretrained(
+            "Xenova/gpt-4", use_fast=False)
+        text = "Testing words with apostrophes such as you're, i'm, don't, etc."
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok = OrtPyFunction.from_model(gen_processing_models(
+            tokenizer,
+            pre_kwargs={"WITH_DEFAULT_INPUTS": True})[0],
+            schema_v2=True)
+        actual_ids = ort_tok([text])[0]
+        np.testing.assert_array_equal(ids, actual_ids)
+
+
+if __name__ == '__main__':
+    unittest.main()
