@@ -98,19 +98,20 @@ def main(
9898 make_zip : bool = False ,
9999 output_folder : str = "dump_models" ,
100100 existing_onnx : str | None = None ,
101+ part : str = "visual" ,
101102):
102103 prefix = simplify_model_id_for_a_filename (model_id )
103104 if "QWEN25ATTENTION" in os .environ :
104105 prefix = f"{ prefix } .{ os .environ ['QWEN25ATTENTION' ]} "
105106 basename = os .path .join (
106- output_folder , f"model.{ prefix } .visual .{ device } .{ dtype } .{ exporter } "
107+ output_folder , f"model.{ prefix } .{ part } .{ device } .{ dtype } .{ exporter } "
107108 )
108109 filename = f"{ basename } .onnx"
109110 stat_file = f"{ basename } .stats"
110111
111112 print ("------------------------------------------------------------------" )
112113 print (
113- f"-- { model_id } { device } { dtype } { exporter } { pretrained } "
114+ f"-- { model_id } { part } { device } { dtype } { exporter } { pretrained } "
114115 f"{ second_input } { make_zip } { output_folder } { prefix } "
115116 )
116117 print ("------------------------------------------------------------------" )
@@ -186,47 +187,75 @@ def _config_reduction(config, task):
186187 print (f"-- model.device={ model .device } " )
187188 processor = AutoProcessor .from_pretrained (model_id , use_fast = True )
188189 print (f"-- processor={ type (processor )} " )
189- model_to_export = model .visual if hasattr (model , "visual" ) else model .model .visual
190- print (f"-- model_to_export={ type (model_to_export )} " )
191-
192- print ("-- ############" )
193- print ("-- INPUT/OUTPUT" )
194- print ("-- ############" )
195-
196- input_filename = os .path .join (output_folder , f"inputs.{ prefix } .visual.{ device } .{ dtype } .pt" )
197- if os .path .exists (input_filename ):
198- print (f"-- restore inputs from { input_filename !r} " )
199- data = torch .load (input_filename )
200- export_inputs = data ["export_inputs" ]
201- other_inputs = data ["other_inputs" ]
202- else :
203- export_inputs = dict (
204- hidden_states = torch .randn ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
205- grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (device ),
190+
191+ if part == "visual" :
192+
193+ class VisualPart (torch .nn .Module ):
194+ def __init__ (self , model ):
195+ super ().__init__ ()
196+ self .model = model
197+
198+ def forward (self , pixel_values , image_grid_thw ):
199+ return model .get_image_features (pixel_values , image_grid_thw )
200+
201+ assert hasattr (
202+ model , "get_image_features"
203+ ), f"get_image_features not found in class { type (model )} "
204+ model_to_export = VisualPart (model )
205+
206+ print (f"-- part={ part !r} " )
207+ print (f"-- model_to_export={ type (model_to_export )} " )
208+
209+ print ("-- ############" )
210+ print ("-- INPUT/OUTPUT" )
211+ print ("-- ############" )
212+
213+ input_filename = os .path .join (
214+ output_folder , f"inputs.{ prefix } .{ part } .{ device } .{ dtype } .pt"
206215 )
207- other_inputs = []
208- if second_input :
209- other_inputs = [
210- dict (
211- hidden_states = torch .randn ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
212- grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (device ),
213- ),
214- dict (
215- hidden_states = torch .rand ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
216- grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (device ),
217- ),
218- dict (
219- hidden_states = torch .randn ((14308 , 1176 ), dtype = torch_dtype ).to (device ),
220- grid_thw = torch .tensor ([[1 , 98 , 146 ]], dtype = torch .int64 ).to (device ),
221- ),
222- dict (
223- hidden_states = torch .rand ((14308 , 1176 ), dtype = torch_dtype ).to (device ),
224- grid_thw = torch .tensor ([[1 , 98 , 146 ]], dtype = torch .int64 ).to (device ),
225- ),
226- ]
227- data = dict (export_inputs = export_inputs , other_inputs = other_inputs )
228- print (f"-- dump inputs into { input_filename !r} " )
229- torch .save (data , input_filename )
216+ if os .path .exists (input_filename ):
217+ print (f"-- restore inputs from { input_filename !r} " )
218+ data = torch .load (input_filename )
219+ export_inputs = data ["export_inputs" ]
220+ other_inputs = data ["other_inputs" ]
221+ else :
222+ export_inputs = dict (
223+ pixel_values = torch .randn ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
224+ image_grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (device ),
225+ )
226+ other_inputs = []
227+ if second_input :
228+ other_inputs = [
229+ dict (
230+ pixel_values = torch .randn ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
231+ image_grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (
232+ device
233+ ),
234+ ),
235+ dict (
236+ pixel_values = torch .rand ((1292 , 1176 ), dtype = torch_dtype ).to (device ),
237+ image_grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (
238+ device
239+ ),
240+ ),
241+ dict (
242+ pixel_values = torch .randn ((14308 , 1176 ), dtype = torch_dtype ).to (device ),
243+ image_grid_thw = torch .tensor ([[1 , 98 , 146 ]], dtype = torch .int64 ).to (
244+ device
245+ ),
246+ ),
247+ dict (
248+ pixel_values = torch .rand ((14308 , 1176 ), dtype = torch_dtype ).to (device ),
249+ image_grid_thw = torch .tensor ([[1 , 98 , 146 ]], dtype = torch .int64 ).to (
250+ device
251+ ),
252+ ),
253+ ]
254+ data = dict (export_inputs = export_inputs , other_inputs = other_inputs )
255+ print (f"-- dump inputs into { input_filename !r} " )
256+ torch .save (data , input_filename )
257+ else :
257+ raise NotImplementedError (f"part={ part !r} not implemented yet" )
230259
231260 print (f"-- export_inputs={ string_type (export_inputs , with_shape = True , with_device = True )} " )
232261 print (f"-- other_inputs={ string_type (other_inputs , with_shape = True , with_device = True )} " )
@@ -290,8 +319,8 @@ def compute_expected():
290319 print ("-- ######" )
291320
292321 dynamic_shapes = dict (
293- hidden_states = {0 : "hidden_width" , 1 : "hidden_height" },
294- grid_thw = {}, # {0: "n_images"}, # TODO: fix
322+ pixel_values = {0 : "hidden_width" , 1 : "hidden_height" },
323+ image_grid_thw = {}, # {0: "n_images"}, # TODO: fix
295324 )
296325
297326 begin = time .perf_counter ()
@@ -336,15 +365,11 @@ def compute_expected():
336365 remove_inplace_body_last_input_output_type_for_loop (filename )
337366 print ("-- done." )
338367
339- with open (stat_file , "w" ) as f :
368+ ###############
369+ # check for discrepancies
370+ ###############
340371
341- def _rename (k ):
342- if rename_inputs :
343- if k == "hidden_states" :
344- return "pixel_values"
345- if k == "grid_thw" :
346- return "image_grid_thw"
347- return k
372+ with open (stat_file , "w" ) as f :
348373
349374 def fprint (s ):
350375 print (s )
@@ -355,23 +380,23 @@ def fprint(s):
355380 if device == "cpu" :
356381 providers = providers [1 :]
357382 fprint (f"-- checking discrepancies with providers={ providers !r} " )
383+ fprint (f"-- filename={ filename !r} " )
358384 sess = onnxruntime .InferenceSession (filename , providers = providers )
359- rename_inputs = sess .get_inputs ()[0 ].name != "hidden_states"
360385
361386 fprint (
362387 f"-- export_inputs { string_type (export_inputs , with_shape = True , with_device = True )} "
363388 )
364389 fprint (
365390 f"-- export_expected { string_type (export_expected , with_shape = True , with_device = True )} "
366391 )
367- feeds = {_rename ( k ) : v .detach ().cpu ().numpy () for k , v in export_inputs .items ()}
392+ feeds = {k : v .detach ().cpu ().numpy () for k , v in export_inputs .items ()}
368393 small = sess .run (None , feeds )
369394 diff = max_diff (export_expected , small [0 ], hist = [0.1 , 0.01 ])
370395 fprint (f"-- discrepancies={ diff } " )
371396
372397 if second_input :
373398 feeds = [
374- {_rename ( k ) : v .detach ().cpu ().numpy () for k , v in inputs .items ()}
399+ {k : v .detach ().cpu ().numpy () for k , v in inputs .items ()}
375400 for inputs in other_inputs
376401 ]
377402 fprint ("" )
@@ -399,6 +424,7 @@ def fprint(s):
399424
400425 info = {
401426 "model_id" : model_id ,
427+ "part" : part ,
402428 "device" : device ,
403429 "dtype" : dtype ,
404430 "exporter" : exporter ,
@@ -432,23 +458,16 @@ def fprint(s):
432458 "timestamp" ,
433459 "model_id" ,
434460 "pretrained" ,
461+ "part" ,
435462 "device" ,
436463 "dtype" ,
437464 "attention" ,
438465 "opset" ,
439466 ]
467+ index = [* first [1 :], "exporter" ]
440468 df = df [[* first , * [c for c in df .columns if c not in set (first )]]]
441469 df .to_excel (statistics + ".xlsx" )
442470
443- index = [
444- "model_id" ,
445- "pretrained" ,
446- "device" ,
447- "dtype" ,
448- "attention" ,
449- "opset" ,
450- "exporter" ,
451- ]
452471 values = [
453472 "abs" ,
454473 "%>0.1" ,
@@ -458,26 +477,16 @@ def fprint(s):
458477 "latency_torch" ,
459478 "latency_ort_n" ,
460479 ]
461- stat = (
462- df [[* index , * values ]]
463- .groupby (index , dropna = False )
464- .agg (
465- {
466- ** {c : "max" for c in values if c != "speedup" },
467- "speedup" : "min" ,
468- }
469- )
470- )
480+ agg = {
481+ ** {c : "max" for c in values if c != "speedup" },
482+ "speedup" : "min" ,
483+ }
484+ stat = df [[* index , * values ]].groupby (index , dropna = False ).agg (agg )
471485 stat .to_excel (statistics + ".agg.xlsx" )
472486 stat = (
473487 df [df .exporter != "custom" ][[* index , * values ]]
474488 .groupby (index , dropna = False )
475- .agg (
476- {
477- ** {c : "max" for c in values if c != "speedup" },
478- "speedup" : "min" ,
479- }
480- )
489+ .agg (agg )
481490 )
482491 stat .to_excel (statistics + ".agg.onnx-dynamo.xlsx" )
483492
@@ -543,6 +552,12 @@ def get_parser() -> ArgumentParser:
543552 default = "" ,
544553 help = "If an onnx file exists, only measures the discrepancies." ,
545554 )
555+ parser .add_argument (
556+ "-p" ,
557+ "--part" ,
558+ default = "visual" ,
559+ help = "part of the model to export" ,
560+ )
546561 return parser
547562
548563
@@ -559,4 +574,5 @@ def get_parser() -> ArgumentParser:
559574 make_zip = args .zip ,
560575 output_folder = args .output_folder ,
561576 existing_onnx = args .existing_onnx ,
577+ part = args .part ,
562578 )
0 commit comments