@@ -97,6 +97,7 @@ def main(
9797 second_input : bool = True ,
9898 make_zip : bool = False ,
9999 output_folder : str = "dump_models" ,
100+ existing_onnx : str | None = None ,
100101):
101102 prefix = simplify_model_id_for_a_filename (model_id )
102103 if "QWEN25ATTENTION" in os .environ :
@@ -115,7 +116,7 @@ def main(
115116 print ("------------------------------------------------------------------" )
116117 print (f"-- export in { filename !r} " )
117118
118- if os .path .exists (stat_file ):
119+ if os .path .exists (stat_file ) and not existing_onnx :
119120 print (f"-- skipping because { stat_file !r} already exists" )
120121 return
121122
@@ -278,55 +279,73 @@ def compute_expected():
278279 compute_expected () if not os .environ .get ("STOPAT" , "" ) else (None , None )
279280 )
280281
281- print ("-- ######" )
282- print ("-- EXPORT" )
283- print ("-- ######" )
282+ if existing_onnx and os .path .exists (existing_onnx ):
283+ exporter = existing_onnx
284+ filename = existing_onnx
285+ export_duration = None
286+ target_opset = None
287+ else :
288+ print ("-- ######" )
289+ print ("-- EXPORT" )
290+ print ("-- ######" )
284291
285- dynamic_shapes = dict (
286- hidden_states = {0 : "hidden_width" , 1 : "hidden_height" },
287- grid_thw = {}, # {0: "n_images"}, # TODO: fix
288- )
292+ dynamic_shapes = dict (
293+ hidden_states = {0 : "hidden_width" , 1 : "hidden_height" },
294+ grid_thw = {}, # {0: "n_images"}, # TODO: fix
295+ )
289296
290- begin = time .perf_counter ()
291-
292- target_opset = 22
293- if exporter == "onnx-dynamo" and device == "cuda" and "QWEN25ATTENTION" not in os .environ :
294- os .environ ["QWEN25ATTENTION" ] = "PACKED"
295- elif "QWEN25ATTENTION" in os .environ and os .environ ["QWEN25ATTENTION" ] == "LOOPA23" :
296- target_opset = 23
297-
298- with torch_export_patches (
299- patch_torch = False ,
300- patch_sympy = False ,
301- patch_transformers = True ,
302- verbose = 1 ,
303- stop_if_static = 2 ,
304- ):
305- if export_expected is None :
306- export_expected , other_expected , durations = compute_expected ()
307- to_onnx (
308- model_to_export ,
309- kwargs = export_inputs ,
310- dynamic_shapes = dynamic_shapes ,
311- filename = filename ,
312- exporter = exporter ,
297+ begin = time .perf_counter ()
298+
299+ target_opset = 22
300+ if (
301+ exporter == "onnx-dynamo"
302+ and device == "cuda"
303+ and "QWEN25ATTENTION" not in os .environ
304+ ):
305+ os .environ ["QWEN25ATTENTION" ] = "PACKED"
306+ elif "QWEN25ATTENTION" in os .environ and os .environ ["QWEN25ATTENTION" ] == "LOOPA23" :
307+ target_opset = 23
308+
309+ with torch_export_patches (
310+ patch_torch = False ,
311+ patch_sympy = False ,
312+ patch_transformers = True ,
313313 verbose = 1 ,
314- save_ep = None ,
315- target_opset = target_opset ,
316- optimize = True ,
317- onnx_plugs = PLUGS ,
318- )
319- export_duration = time .perf_counter () - begin
314+ stop_if_static = 2 ,
315+ ):
316+ if export_expected is None :
317+ export_expected , other_expected , durations = compute_expected ()
318+ to_onnx (
319+ model_to_export ,
320+ kwargs = export_inputs ,
321+ dynamic_shapes = dynamic_shapes ,
322+ filename = filename ,
323+ exporter = exporter ,
324+ verbose = 1 ,
325+ save_ep = None ,
326+ target_opset = target_opset ,
327+ optimize = True ,
328+ onnx_plugs = PLUGS ,
329+ )
330+ export_duration = time .perf_counter () - begin
320331
321- if exporter == "onnx-dynamo" :
322- # onnx-dynamo fails at producing function body with sequences as input / output.
323- # They are replaced by tensor type one step in the model.
324- print ("-- remove_body_last_input_output_for_loop" )
325- remove_inplace_body_last_input_output_type_for_loop (filename )
326- print ("-- done." )
332+ if exporter == "onnx-dynamo" :
333+ # onnx-dynamo fails at producing function body with sequences as input / output.
334+ # They are replaced by tensor type one step in the model.
335+ print ("-- remove_body_last_input_output_for_loop" )
336+ remove_inplace_body_last_input_output_type_for_loop (filename )
337+ print ("-- done." )
327338
328339 with open (stat_file , "w" ) as f :
329340
341+ def _rename (k ): # Translate an input name of the exported model into the name used to feed the onnxruntime session.
342+ if rename_inputs : # closure flag assigned after the session is created: true when its first input is not named "hidden_states"
343+ if k == "hidden_states" :
344+ return "pixel_values"
345+ if k == "grid_thw" :
346+ return "image_grid_thw"
347+ return k # any other name is returned unchanged (presumably every name when rename_inputs is false; indentation is lost in this view)
348+
330349 def fprint (s ): # Report one message both on the console and in the statistics file.
331350 print (s )
332351 f .write (f"{ s } \n " ) # `f` is the handle opened on stat_file by the enclosing `with open(...)` statement
@@ -337,21 +356,22 @@ def fprint(s):
337356 providers = providers [1 :]
338357 fprint (f"-- checking discrepancies with providers={ providers !r} " )
339358 sess = onnxruntime .InferenceSession (filename , providers = providers )
359+ rename_inputs = sess .get_inputs ()[0 ].name != "hidden_states"
340360
341361 fprint (
342362 f"-- export_inputs { string_type (export_inputs , with_shape = True , with_device = True )} "
343363 )
344364 fprint (
345365 f"-- export_expected { string_type (export_expected , with_shape = True , with_device = True )} "
346366 )
347- feeds = {k : v .detach ().cpu ().numpy () for k , v in export_inputs .items ()}
367+ feeds = {_rename ( k ) : v .detach ().cpu ().numpy () for k , v in export_inputs .items ()}
348368 small = sess .run (None , feeds )
349369 diff = max_diff (export_expected , small [0 ], hist = [0.1 , 0.01 ])
350370 fprint (f"-- discrepancies={ diff } " )
351371
352372 if second_input :
353373 feeds = [
354- {k : v .detach ().cpu ().numpy () for k , v in inputs .items ()}
374+ {_rename ( k ) : v .detach ().cpu ().numpy () for k , v in inputs .items ()}
355375 for inputs in other_inputs
356376 ]
357377 fprint ("" )
@@ -440,7 +460,7 @@ def fprint(s):
440460 ]
441461 stat = (
442462 df [[* index , * values ]]
443- .groupby (index )
463+ .groupby (index , dropna = False )
444464 .agg (
445465 {
446466 ** {c : "max" for c in values if c != "speedup" },
@@ -450,8 +470,8 @@ def fprint(s):
450470 )
451471 stat .to_excel (statistics + ".agg.xlsx" )
452472 stat = (
453- df [df .exporter == "onnx-dynamo " ][[* index , * values ]]
454- .groupby (index )
473+ df [df .exporter != "custom " ][[* index , * values ]]
474+ .groupby (index , dropna = False )
455475 .agg (
456476 {
457477 ** {c : "max" for c in values if c != "speedup" },
@@ -517,6 +537,12 @@ def get_parser() -> ArgumentParser:
517537 help = "Folders where to put the results." ,
518538 action = BooleanOptionalAction ,
519539 )
540+ parser .add_argument (
541+ "-x" ,
542+ "--existing-onnx" ,
543+ default = "" ,
544+ help = "If an onnx file exists, only measures the disrepancies." ,
545+ )
520546 return parser
521547
522548
@@ -532,4 +558,5 @@ def get_parser() -> ArgumentParser:
532558 second_input = args .second_input ,
533559 make_zip = args .zip ,
534560 output_folder = args .output_folder ,
561+ existing_onnx = args .existing_onnx ,
535562 )
0 commit comments