Skip to content

Commit c3705cb

Browse files
committed
Changed _export to be the default one
1 parent c497756 commit c3705cb

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

tools/llm/utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,37 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
1616
In the case of guard failures due to some PyTorch kernel implementations, we also
1717
try to re-export the graph by expressing them as runtime assert nodes
1818
"""
19+
print(
20+
f"Exporting model with min_seq_len={min_seq_len} and max_seq_len={max_seq_len}"
21+
)
1922
with torch.no_grad():
2023
# max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604
2124
seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
2225
position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
2326
try:
24-
print("Trying to export the model using torch.export.export()..")
27+
print("Trying to export the model using torch.export._trace._export()..")
2528
# strict=False only enables aotautograd tracing and excludes dynamo.
26-
ep = torch.export.export(
29+
ep = torch.export._trace._export(
2730
model,
2831
args=(inputs,),
2932
kwargs={"position_ids": position_ids},
3033
dynamic_shapes=({1: seq_len}, {1: seq_len}),
3134
strict=False,
35+
allow_complex_guards_as_runtime_asserts=True,
3236
)
37+
3338
except:
3439
print(
35-
"Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
40+
"Trying torch.export.export to trace the graph since torch.export._trace._export() failed"
3641
)
3742
# This API is used to express the constraint violation guards as asserts in the graph.
38-
ep = torch.export._trace._export(
43+
44+
ep = torch.export.export(
3945
model,
4046
args=(inputs,),
4147
kwargs={"position_ids": position_ids},
4248
dynamic_shapes=({1: seq_len}, {1: seq_len}),
4349
strict=False,
44-
prefer_deferred_runtime_asserts_over_guards=True,
4550
)
4651

4752
return ep
@@ -223,6 +228,7 @@ def time_generate(
223228
"""
224229
timings = []
225230
for _ in range(iterations):
231+
print(f"Iteration {_} of {iterations}")
226232
start_time = timeit.default_timer()
227233
_ = generate_fn(model, inputs, output_seq_length, eos_token_id)
228234
torch.cuda.synchronize()

0 commit comments

Comments
 (0)