Skip to content

Commit a29f2af

Browse files
authored
tiny modifications (#350)
1 parent 3fb23bd commit a29f2af

File tree

1 file changed

+75
-48
lines changed

1 file changed

+75
-48
lines changed

_scripts/export_qwen25_vl_visual.py

Lines changed: 75 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def main(
9797
second_input: bool = True,
9898
make_zip: bool = False,
9999
output_folder: str = "dump_models",
100+
existing_onnx: str | None = None,
100101
):
101102
prefix = simplify_model_id_for_a_filename(model_id)
102103
if "QWEN25ATTENTION" in os.environ:
@@ -115,7 +116,7 @@ def main(
115116
print("------------------------------------------------------------------")
116117
print(f"-- export in {filename!r}")
117118

118-
if os.path.exists(stat_file):
119+
if os.path.exists(stat_file) and not existing_onnx:
119120
print(f"-- skipping because {stat_file!r} already exists")
120121
return
121122

@@ -278,55 +279,73 @@ def compute_expected():
278279
compute_expected() if not os.environ.get("STOPAT", "") else (None, None)
279280
)
280281

281-
print("-- ######")
282-
print("-- EXPORT")
283-
print("-- ######")
282+
if existing_onnx and os.path.exists(existing_onnx):
283+
exporter = existing_onnx
284+
filename = existing_onnx
285+
export_duration = None
286+
target_opset = None
287+
else:
288+
print("-- ######")
289+
print("-- EXPORT")
290+
print("-- ######")
284291

285-
dynamic_shapes = dict(
286-
hidden_states={0: "hidden_width", 1: "hidden_height"},
287-
grid_thw={}, # {0: "n_images"}, # TODO: fix
288-
)
292+
dynamic_shapes = dict(
293+
hidden_states={0: "hidden_width", 1: "hidden_height"},
294+
grid_thw={}, # {0: "n_images"}, # TODO: fix
295+
)
289296

290-
begin = time.perf_counter()
291-
292-
target_opset = 22
293-
if exporter == "onnx-dynamo" and device == "cuda" and "QWEN25ATTENTION" not in os.environ:
294-
os.environ["QWEN25ATTENTION"] = "PACKED"
295-
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
296-
target_opset = 23
297-
298-
with torch_export_patches(
299-
patch_torch=False,
300-
patch_sympy=False,
301-
patch_transformers=True,
302-
verbose=1,
303-
stop_if_static=2,
304-
):
305-
if export_expected is None:
306-
export_expected, other_expected, durations = compute_expected()
307-
to_onnx(
308-
model_to_export,
309-
kwargs=export_inputs,
310-
dynamic_shapes=dynamic_shapes,
311-
filename=filename,
312-
exporter=exporter,
297+
begin = time.perf_counter()
298+
299+
target_opset = 22
300+
if (
301+
exporter == "onnx-dynamo"
302+
and device == "cuda"
303+
and "QWEN25ATTENTION" not in os.environ
304+
):
305+
os.environ["QWEN25ATTENTION"] = "PACKED"
306+
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
307+
target_opset = 23
308+
309+
with torch_export_patches(
310+
patch_torch=False,
311+
patch_sympy=False,
312+
patch_transformers=True,
313313
verbose=1,
314-
save_ep=None,
315-
target_opset=target_opset,
316-
optimize=True,
317-
onnx_plugs=PLUGS,
318-
)
319-
export_duration = time.perf_counter() - begin
314+
stop_if_static=2,
315+
):
316+
if export_expected is None:
317+
export_expected, other_expected, durations = compute_expected()
318+
to_onnx(
319+
model_to_export,
320+
kwargs=export_inputs,
321+
dynamic_shapes=dynamic_shapes,
322+
filename=filename,
323+
exporter=exporter,
324+
verbose=1,
325+
save_ep=None,
326+
target_opset=target_opset,
327+
optimize=True,
328+
onnx_plugs=PLUGS,
329+
)
330+
export_duration = time.perf_counter() - begin
320331

321-
if exporter == "onnx-dynamo":
322-
# onnx-dynamo fails at producing function body with sequences as input / output.
323-
# They are replaced by tensor type one step in the model.
324-
print("-- remove_body_last_input_output_for_loop")
325-
remove_inplace_body_last_input_output_type_for_loop(filename)
326-
print("-- done.")
332+
if exporter == "onnx-dynamo":
333+
# onnx-dynamo fails at producing function body with sequences as input / output.
334+
# They are replaced by tensor type one step in the model.
335+
print("-- remove_body_last_input_output_for_loop")
336+
remove_inplace_body_last_input_output_type_for_loop(filename)
337+
print("-- done.")
327338

328339
with open(stat_file, "w") as f:
329340

341+
def _rename(k):
342+
if rename_inputs:
343+
if k == "hidden_states":
344+
return "pixel_values"
345+
if k == "grid_thw":
346+
return "image_grid_thw"
347+
return k
348+
330349
def fprint(s):
331350
print(s)
332351
f.write(f"{s}\n")
@@ -337,21 +356,22 @@ def fprint(s):
337356
providers = providers[1:]
338357
fprint(f"-- checking discrepancies with providers={providers!r}")
339358
sess = onnxruntime.InferenceSession(filename, providers=providers)
359+
rename_inputs = sess.get_inputs()[0].name != "hidden_states"
340360

341361
fprint(
342362
f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
343363
)
344364
fprint(
345365
f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}"
346366
)
347-
feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()}
367+
feeds = {_rename(k): v.detach().cpu().numpy() for k, v in export_inputs.items()}
348368
small = sess.run(None, feeds)
349369
diff = max_diff(export_expected, small[0], hist=[0.1, 0.01])
350370
fprint(f"-- discrepancies={diff}")
351371

352372
if second_input:
353373
feeds = [
354-
{k: v.detach().cpu().numpy() for k, v in inputs.items()}
374+
{_rename(k): v.detach().cpu().numpy() for k, v in inputs.items()}
355375
for inputs in other_inputs
356376
]
357377
fprint("")
@@ -440,7 +460,7 @@ def fprint(s):
440460
]
441461
stat = (
442462
df[[*index, *values]]
443-
.groupby(index)
463+
.groupby(index, dropna=False)
444464
.agg(
445465
{
446466
**{c: "max" for c in values if c != "speedup"},
@@ -450,8 +470,8 @@ def fprint(s):
450470
)
451471
stat.to_excel(statistics + ".agg.xlsx")
452472
stat = (
453-
df[df.exporter == "onnx-dynamo"][[*index, *values]]
454-
.groupby(index)
473+
df[df.exporter != "custom"][[*index, *values]]
474+
.groupby(index, dropna=False)
455475
.agg(
456476
{
457477
**{c: "max" for c in values if c != "speedup"},
@@ -517,6 +537,12 @@ def get_parser() -> ArgumentParser:
517537
help="Folders where to put the results.",
518538
action=BooleanOptionalAction,
519539
)
540+
parser.add_argument(
541+
"-x",
542+
"--existing-onnx",
543+
default="",
544+
help="If an onnx file exists, only measures the discrepancies.",
545+
)
520546
return parser
521547

522548

@@ -532,4 +558,5 @@ def get_parser() -> ArgumentParser:
532558
second_input=args.second_input,
533559
make_zip=args.zip,
534560
output_folder=args.output_folder,
561+
existing_onnx=args.existing_onnx,
535562
)

0 commit comments

Comments
 (0)