@@ -16,32 +16,37 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
1616 In the case of guard failures due to some PyTorch kernel implementations, we also
1717 try to re-export the graph by expressing them as runtime assert nodes
1818 """
19+ print (
20+ f"Exporting model with min_seq_len={ min_seq_len } and max_seq_len={ max_seq_len } "
21+ )
1922 with torch .no_grad ():
2023 # max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604
2124 seq_len = torch .export .Dim ("seq_len" , min = min_seq_len , max = max_seq_len )
2225 position_ids = torch .arange (inputs .shape [1 ]).unsqueeze (0 ).to (inputs .device )
2326 try :
24- print ("Trying to export the model using torch.export.export ().." )
27+ print ("Trying to export the model using torch.export._trace._export ().." )
2528 # strict=False only enables aotautograd tracing and excludes dynamo.
26- ep = torch .export .export (
29+ ep = torch .export ._trace . _export (
2730 model ,
2831 args = (inputs ,),
2932 kwargs = {"position_ids" : position_ids },
3033 dynamic_shapes = ({1 : seq_len }, {1 : seq_len }),
3134 strict = False ,
35+ allow_complex_guards_as_runtime_asserts = True ,
3236 )
37+
3338 except :
3439 print (
35- "Trying torch.export._trace._export to trace the graph since torch.export.export () failed"
40+ "Trying torch.export.export to trace the graph since torch.export._trace._export () failed"
3641 )
3742 # This API is used to express the constraint violation guards as asserts in the graph.
38- ep = torch .export ._trace ._export (
43+
44+ ep = torch .export .export (
3945 model ,
4046 args = (inputs ,),
4147 kwargs = {"position_ids" : position_ids },
4248 dynamic_shapes = ({1 : seq_len }, {1 : seq_len }),
4349 strict = False ,
44- prefer_deferred_runtime_asserts_over_guards = True ,
4550 )
4651
4752 return ep
@@ -223,6 +228,7 @@ def time_generate(
223228 """
224229 timings = []
225230 for _ in range (iterations ):
231+ print (f"Iteration { _ } of { iterations } " )
226232 start_time = timeit .default_timer ()
227233 _ = generate_fn (model , inputs , output_seq_length , eos_token_id )
228234 torch .cuda .synchronize ()
0 commit comments