from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph  # noqa
from ._ortapi2 import make_onnx_model, SingleOpGraph

import os
import numpy as np
import tempfile
import shutil

# Force the pure-Python protobuf implementation to avoid version-mismatch errors;
# this must be set before transformers (and its protobuf dependency) is imported.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from transformers.convert_slow_tokenizer import SpmConverter  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402
from tokenizers import decoders, normalizers, pre_tokenizers, Regex  # noqa: E402


OrtxTokenizer = None
try:
    from onnxruntime_extensions.pp_api import Tokenizer as OrtxTokenizer
except ImportError:
    pass

_is_torch_available = False
try:
    import torch  # noqa
    _is_torch_available = True
    from ._torch_cvt import WhisperDataProcGraph
except ImportError:
    WhisperDataProcGraph = None

_PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}


def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
    if add_prefix_space:
        prepend_scheme = "always"
        if not getattr(original_tokenizer, "legacy", True):
            prepend_scheme = "first"
    else:
        prepend_scheme = "never"
    return prepend_scheme
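
# Illustrative-only sketch of the mapping above, using a hypothetical stand-in
# for a tokenizer object (not a real transformers class):
#   class _Tok:
#       legacy = False
#   _get_prepend_scheme(True, _Tok())    # -> "first"  (non-legacy)
#   _get_prepend_scheme(True, object())  # -> "always" (no `legacy` attr, treated as legacy)
#   _get_prepend_scheme(False, _Tok())   # -> "never"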


class Baichuan2Converter(SpmConverter):
    handle_byte_fallback = True

    def __init__(self, original_tokenizer):
        super().__init__(original_tokenizer)
        original_tokenizer.add_prefix_space = False

    def vocab(self, proto):
        # the first three ids (typically unk/bos/eos for SentencePiece models) keep score 0.0
        vocab = [
            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return 0

    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def normalizer(self, proto):
        if getattr(self.original_tokenizer, "legacy", True):
            sequence = []
            if getattr(self.original_tokenizer, "add_prefix_space", True):
                sequence += [normalizers.Prepend(prepend="▁")]
            sequence += [normalizers.Replace(pattern=" ", content="▁")]
            return normalizers.Sequence(sequence)
        return None  # non-legacy, no normalizer

    def pre_tokenizer(self, replacement, add_prefix_space):
        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
        else:
            return super().pre_tokenizer(replacement, add_prefix_space)
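
# A minimal usage sketch (the model id is illustrative and requires HF hub access):
#   slow_tok = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
#                                            trust_remote_code=True, use_fast=False)
#   fast_tok = Baichuan2Converter(slow_tok).converted()  # -> a tokenizers.Tokenizer
#   fast_tok.save("tokenizer.json")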


class ChatGlmConverter(SpmConverter):
    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = "always"
        if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
            prepend_scheme = "first"
        return pre_tokenizers.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
        )


JSON_TOKEN_CONVERTERS = {
    "BaichuanTokenizer": Baichuan2Converter,
    "ChatGLMTokenizer": ChatGlmConverter,
}


# Save tokenizer JSON files using HuggingFace AutoTokenizer
def convert_tokenizer(model_path, output_dir):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if output_dir is None:
        if os.path.isdir(model_path):
            output_dir = model_path
        else:
            # create a temporary directory for a model id that is not a local path
            output_dir = tempfile.mkdtemp()
    tokenizer.save_pretrained(output_dir)
    json_path = os.path.join(output_dir, "tokenizer.json")

    if type(tokenizer).__name__ in JSON_TOKEN_CONVERTERS:
        GenericSpmConverter = JSON_TOKEN_CONVERTERS[type(tokenizer).__name__]
        converted = GenericSpmConverter(tokenizer).converted()
        converted.save(json_path)
        print(f"Tokenizer saved to {json_path}")
    return output_dir
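
# Example (illustrative model id; downloads the tokenizer from the HF hub):
#   out_dir = convert_tokenizer("baichuan-inc/Baichuan2-7B-Chat", None)
#   # out_dir now contains tokenizer.json and tokenizer_config.json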


# Validate the generated tokenizer files by comparing the ortx and HF tokenization results
def validate_tokenizer(model_path, output_dir):
    test_sentence = "I like walking my cute dog\n and\x17 then, 生活的真谛是 \t \t \t \t \n \n 61"
    if OrtxTokenizer is None:
        print("onnxruntime_extensions package was not built with pp_api enabled, skipping tokenization test")
        return
    ortx_tokenizer = OrtxTokenizer(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
    expected_ids = tokenizer(test_sentence, return_tensors="np")["input_ids"]
    ortx_ids = np.asarray(ortx_tokenizer.tokenize(test_sentence))
    assert np.array_equal(expected_ids[0], ortx_ids), f"Tokenization mismatch: {expected_ids[0]} != {ortx_ids}"
    print("Tokenization test passed")
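
# Example (illustrative id/path): compare against the slow HF tokenizer
#   validate_tokenizer("baichuan-inc/Baichuan2-7B-Chat", out_dir)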


# Download the tokenizer JSON files from the HF hub, or resolve them from the local cache
def download_tokenizer(tokenizer_dir, output_dir):
    try:
        from transformers.utils import cached_file

        resolved_full_file = cached_file(tokenizer_dir, "tokenizer.json")
        resolved_config_file = cached_file(tokenizer_dir, "tokenizer_config.json")
    except ImportError:
        raise ValueError(f"Cannot resolve '{tokenizer_dir}': transformers.utils.cached_file is not available")
    if not os.path.exists(resolved_full_file):
        raise FileNotFoundError(f"Downloaded HF file '{resolved_full_file}' cannot be found")
    if os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file):
        raise FileNotFoundError(
            f"Downloaded HF files '{resolved_full_file}' "
            f"and '{resolved_config_file}' are not in the same directory")

    if output_dir is None or len(output_dir) == 0:
        output_dir = os.path.dirname(resolved_full_file)
        print(f"Using {output_dir} as output directory")
        return output_dir
    else:
        # copy the files to the output directory
        shutil.copy(resolved_full_file, output_dir)
        shutil.copy(resolved_config_file, output_dir)
        return output_dir
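
# Example (illustrative repo id): resolve the JSON files into the local HF cache
#   cache_dir = download_tokenizer("microsoft/Phi-3-mini-4k-instruct", None)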


def gen_processing_models(processor: Union[str, object],
                          pre_kwargs: dict = None,
                          post_kwargs: dict = None,
                          opset: int = None,
                          schema_v2: bool = False,
                          **kwargs):
    """
    Generate the pre- and post-processing ONNX models, based on the processor name or HF class.

    Parameters
    ----------
    processor: Union[str, object]
        the HF tokenizer/processor instance, or the name of the processing operator
    pre_kwargs: dict
        Keyword arguments for generating the pre-processing model
    post_kwargs: dict
        Keyword arguments for generating the post-processing model
    opset: int
        the target opset version of the model
    schema_v2: bool
        the flag for using embedded tokenizer files; this option leverages the blob-loading
        functionality, which loads the HF tokenizer from in-memory JSON strings stored in the
        node attributes rather than from tokenizer files in HF JSON format on disk.
    kwargs:
        The additional arguments for generating models

    Returns
    -------
    A tuple of the pre-processing and post-processing ONNX models; either element is
    None when the corresponding kwargs were not provided.
    """
    if pre_kwargs is None and post_kwargs is None:
        raise ValueError(
            "Either pre_kwargs or post_kwargs should be provided. None means no processing graph output.")

    # If true, we get the tokenizer JSON files either by downloading them from the cache or by
    # converting them with the HuggingFace AutoTokenizer, and then create an ONNX model with the
    # JSON files stored as strings in the model attributes (attrs).
    if schema_v2:
        model_name = processor if isinstance(processor, str) else type(processor).__name__

        # tokenizers that must be converted to the HF JSON format first
        converted_tokenizer = {"Baichuan2", "chatglm"}
        need_convert = any(token in model_name for token in converted_tokenizer)

        if need_convert:
            model_dir = convert_tokenizer(model_name, None)
            validate_tokenizer(model_name, model_dir)
        else:
            model_dir = download_tokenizer(model_name, None)

        # Load the content of tokenizer.json into a string
        with open(f"{model_dir}/tokenizer.json", "r", encoding="utf-8") as f:
            tokenizer_vocab = f.read()

        # Load the content of tokenizer_config.json into a string
        with open(f"{model_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
            tokenizer_config = f.read()

        # Create an ONNX model with these JSON file strings in attrs
        g_pre, g_post = (None, None)
        if pre_kwargs is not None:
            # Add tokenizer_vocab and tokenizer_config to the kwargs
            # so they are added to attrs in build_graph
            pre_kwargs['tokenizer_vocab'] = tokenizer_vocab
            pre_kwargs['tokenizer_config'] = tokenizer_config
            g_pre = SingleOpGraph.build_graph("HfJsonTokenizer", **pre_kwargs)
        if post_kwargs is not None:
            if pre_kwargs is None:
                cls_name = processor
            else:
                if processor not in _PRE_POST_PAIR:
                    raise RuntimeError(
                        f"Cannot locate the post processing operator name from {processor}")
                cls_name = _PRE_POST_PAIR[processor]
            # Add tokenizer_vocab and tokenizer_config to the kwargs
            # so they are added to attrs in build_graph
            post_kwargs['tokenizer_vocab'] = tokenizer_vocab
            post_kwargs['tokenizer_config'] = tokenizer_config
            g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
        return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None
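
        # For illustration (hypothetical model id), a schema_v2 call looks like:
        #   pre_m, _ = gen_processing_models("baichuan-inc/Baichuan2-7B-Chat",
        #                                    pre_kwargs={}, schema_v2=True)
        # The resulting node carries tokenizer.json/tokenizer_config.json as string attributes.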
    else:
        if isinstance(processor, str):
            g_pre, g_post = (None, None)
            if pre_kwargs:
                g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
            if post_kwargs:
                if pre_kwargs is None:
                    cls_name = processor
                else:
                    if processor not in _PRE_POST_PAIR:
                        raise RuntimeError(
                            f"Cannot locate the post processing operator name from {processor}")
                    cls_name = _PRE_POST_PAIR[processor]
                g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
            return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

        cls_name = type(processor).__name__
        if cls_name == "WhisperProcessor":
            if WhisperDataProcGraph is None:
                raise ValueError(
                    "The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
            _converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
            pre_m = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
            post_m = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
            return pre_m, post_m
        elif HFTokenizerOnnxGraph.is_supported(processor):
            _converter = HFTokenizerOnnxGraph(processor)
            pre_g = _converter.pre_processing(**pre_kwargs) if pre_kwargs is not None else None
            post_g = _converter.post_processing(**post_kwargs) if post_kwargs is not None else None
            return make_onnx_model(pre_g) if pre_g else None, \
                make_onnx_model(post_g) if post_g else None
        else:
            raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")