Commit 4d706e0

minor (#351)
1 parent a29f2af commit 4d706e0


_scripts/export_qwen25_vl_visual.py

Lines changed: 96 additions & 80 deletions
@@ -98,19 +98,20 @@ def main(
     make_zip: bool = False,
     output_folder: str = "dump_models",
     existing_onnx: str | None = None,
+    part: str = "visual",
 ):
     prefix = simplify_model_id_for_a_filename(model_id)
     if "QWEN25ATTENTION" in os.environ:
         prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
     basename = os.path.join(
-        output_folder, f"model.{prefix}.visual.{device}.{dtype}.{exporter}"
+        output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}"
     )
     filename = f"{basename}.onnx"
     stat_file = f"{basename}.stats"

     print("------------------------------------------------------------------")
     print(
-        f"-- {model_id} {device} {dtype} {exporter} {pretrained} "
+        f"-- {model_id} {part} {device} {dtype} {exporter} {pretrained} "
         f"{second_input} {make_zip} {output_folder} {prefix}"
     )
     print("------------------------------------------------------------------")
@@ -186,47 +187,75 @@ def _config_reduction(config, task):
     print(f"-- model.device={model.device}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
     print(f"-- processor={type(processor)}")
-    model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
-    print(f"-- model_to_export={type(model_to_export)}")
-
-    print("-- ############")
-    print("-- INPUT/OUTPUT")
-    print("-- ############")
-
-    input_filename = os.path.join(output_folder, f"inputs.{prefix}.visual.{device}.{dtype}.pt")
-    if os.path.exists(input_filename):
-        print(f"-- restore inputs from {input_filename!r}")
-        data = torch.load(input_filename)
-        export_inputs = data["export_inputs"]
-        other_inputs = data["other_inputs"]
-    else:
-        export_inputs = dict(
-            hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
-            grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
+
+    if part == "visual":
+
+        class VisualPart(torch.nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, pixel_values, image_grid_thw):
+                return self.model.get_image_features(pixel_values, image_grid_thw)
+
+        assert hasattr(
+            model, "get_image_features"
+        ), f"get_image_features not found in class {type(model)}"
+        model_to_export = VisualPart(model)
+
+        print(f"-- part={part!r}")
+        print(f"-- model_to_export={type(model_to_export)}")
+
+        print("-- ############")
+        print("-- INPUT/OUTPUT")
+        print("-- ############")
+
+        input_filename = os.path.join(
+            output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt"
         )
-        other_inputs = []
-        if second_input:
-            other_inputs = [
-                dict(
-                    hidden_states=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
-                    grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
-                ),
-                dict(
-                    hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
-                    grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
-                ),
-                dict(
-                    hidden_states=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
-                    grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
-                ),
-                dict(
-                    hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
-                    grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
-                ),
-            ]
-        data = dict(export_inputs=export_inputs, other_inputs=other_inputs)
-        print(f"-- dump inputs into {input_filename!r}")
-        torch.save(data, input_filename)
+        if os.path.exists(input_filename):
+            print(f"-- restore inputs from {input_filename!r}")
+            data = torch.load(input_filename)
+            export_inputs = data["export_inputs"]
+            other_inputs = data["other_inputs"]
+        else:
+            export_inputs = dict(
+                pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
+                image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
+            )
+            other_inputs = []
+            if second_input:
+                other_inputs = [
+                    dict(
+                        pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
+                        image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(
+                            device
+                        ),
+                    ),
+                    dict(
+                        pixel_values=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
+                        image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(
+                            device
+                        ),
+                    ),
+                    dict(
+                        pixel_values=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
+                        image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(
+                            device
+                        ),
+                    ),
+                    dict(
+                        pixel_values=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
+                        image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(
+                            device
+                        ),
+                    ),
+                ]
+            data = dict(export_inputs=export_inputs, other_inputs=other_inputs)
+            print(f"-- dump inputs into {input_filename!r}")
+            torch.save(data, input_filename)
+    else:
+        raise NotImplementedError(f"part={part!r} not implemented yet")

     print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
     print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
@@ -290,8 +319,8 @@ def compute_expected():
     print("-- ######")

     dynamic_shapes = dict(
-        hidden_states={0: "hidden_width", 1: "hidden_height"},
-        grid_thw={},  # {0: "n_images"},  # TODO: fix
+        pixel_values={0: "hidden_width", 1: "hidden_height"},
+        image_grid_thw={},  # {0: "n_images"},  # TODO: fix
     )

     begin = time.perf_counter()
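
The dynamic-shape keys are renamed to match the wrapper's argument names, which is what the exporter resolves them against. As an illustration only (the script's actual exporter dispatch sits outside this hunk), the mapping would feed the dynamo-based `torch.onnx.export` like this:

# Illustration only: the script supports several exporters and the real call
# is not part of this diff. filename, model_to_export, export_inputs and
# dynamic_shapes are the variables defined above.
torch.onnx.export(
    model_to_export,
    (),
    filename,
    kwargs=export_inputs,           # pixel_values, image_grid_thw
    dynamic_shapes=dynamic_shapes,  # keys must match forward() arguments
    dynamo=True,
)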
@@ -336,15 +365,11 @@ def compute_expected():
     remove_inplace_body_last_input_output_type_for_loop(filename)
     print("-- done.")

-    with open(stat_file, "w") as f:
+    ###############
+    # check for discrepancies
+    ###############

-        def _rename(k):
-            if rename_inputs:
-                if k == "hidden_states":
-                    return "pixel_values"
-                if k == "grid_thw":
-                    return "image_grid_thw"
-            return k
+    with open(stat_file, "w") as f:

         def fprint(s):
             print(s)
@@ -355,23 +380,23 @@ def fprint(s):
         if device == "cpu":
             providers = providers[1:]
         fprint(f"-- checking discrepancies with providers={providers!r}")
+        fprint(f"-- filename={filename!r}")
         sess = onnxruntime.InferenceSession(filename, providers=providers)
-        rename_inputs = sess.get_inputs()[0].name != "hidden_states"

         fprint(
             f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
         )
         fprint(
             f"-- export_expected {string_type(export_expected, with_shape=True, with_device=True)}"
         )
-        feeds = {_rename(k): v.detach().cpu().numpy() for k, v in export_inputs.items()}
+        feeds = {k: v.detach().cpu().numpy() for k, v in export_inputs.items()}
         small = sess.run(None, feeds)
         diff = max_diff(export_expected, small[0], hist=[0.1, 0.01])
         fprint(f"-- discrepancies={diff}")

         if second_input:
             feeds = [
-                {_rename(k): v.detach().cpu().numpy() for k, v in inputs.items()}
+                {k: v.detach().cpu().numpy() for k, v in inputs.items()}
                 for inputs in other_inputs
             ]
             fprint("")
@@ -399,6 +424,7 @@ def fprint(s):

         info = {
             "model_id": model_id,
+            "part": part,
             "device": device,
             "dtype": dtype,
             "exporter": exporter,
@@ -432,23 +458,16 @@ def fprint(s):
         "timestamp",
         "model_id",
         "pretrained",
+        "part",
         "device",
         "dtype",
         "attention",
         "opset",
     ]
+    index = [*first[1:], "exporter"]
     df = df[[*first, *[c for c in df.columns if c not in set(first)]]]
     df.to_excel(statistics + ".xlsx")

-    index = [
-        "model_id",
-        "pretrained",
-        "device",
-        "dtype",
-        "attention",
-        "opset",
-        "exporter",
-    ]
     values = [
         "abs",
         "%>0.1",
@@ -458,26 +477,16 @@ def fprint(s):
         "latency_torch",
         "latency_ort_n",
     ]
-    stat = (
-        df[[*index, *values]]
-        .groupby(index, dropna=False)
-        .agg(
-            {
-                **{c: "max" for c in values if c != "speedup"},
-                "speedup": "min",
-            }
-        )
-    )
+    agg = {
+        **{c: "max" for c in values if c != "speedup"},
+        "speedup": "min",
+    }
+    stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg)
     stat.to_excel(statistics + ".agg.xlsx")
     stat = (
         df[df.exporter != "custom"][[*index, *values]]
         .groupby(index, dropna=False)
-        .agg(
-            {
-                **{c: "max" for c in values if c != "speedup"},
-                "speedup": "min",
-            }
-        )
+        .agg(agg)
     )
     stat.to_excel(statistics + ".agg.onnx-dynamo.xlsx")
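
The refactor hoists the shared aggregation spec into a single `agg` dict used by both groupbys: every column keeps its worst value per group, meaning the maximum for error and latency columns but the minimum for `speedup`. A toy reproduction:

import pandas as pd

df = pd.DataFrame(
    {
        "exporter": ["onnx-dynamo", "onnx-dynamo", "custom"],
        "abs": [2e-3, 5e-4, 1e-3],
        "speedup": [0.9, 1.1, 1.3],
    }
)
agg = {"abs": "max", "speedup": "min"}
# custom keeps abs=0.001, speedup=1.3; onnx-dynamo keeps abs=0.002, speedup=0.9
print(df.groupby("exporter", dropna=False).agg(agg))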

@@ -543,6 +552,12 @@ def get_parser() -> ArgumentParser:
         default="",
         help="If an onnx file exists, only measures the discrepancies.",
     )
+    parser.add_argument(
+        "-p",
+        "--part",
+        default="visual",
+        help="part of the model to export",
+    )
     return parser


@@ -559,4 +574,5 @@ def get_parser() -> ArgumentParser:
         make_zip=args.zip,
         output_folder=args.output_folder,
         existing_onnx=args.existing_onnx,
+        part=args.part,
     )
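
With the flag wired through to `main`, the part to export is chosen on the command line; only "visual" is implemented so far, and any other value raises `NotImplementedError`. A typical invocation (remaining options keep their defaults):

python _scripts/export_qwen25_vl_visual.py --part visual
python _scripts/export_qwen25_vl_visual.py -p visual   # short form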
