@@ -49,7 +49,6 @@ def get_model(args):
         torch.nn.Module: The loaded and configured model ready for inference,
         moved to CUDA device with the specified precision
     """
-
     with torch.no_grad():
         model = (
             AutoModelForCausalLM.from_pretrained(
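
Note on the hunk above: wrapping from_pretrained in torch.no_grad() keeps autograd from recording anything while the weights are materialized and moved to the GPU. A minimal sketch of the same idiom, with a hypothetical checkpoint ID standing in for args.model:

    import torch
    from transformers import AutoModelForCausalLM

    MODEL_ID = "gpt2"  # hypothetical stand-in for args.model

    with torch.no_grad():
        # No gradient bookkeeping while weights load and move to the GPU.
        model = (
            AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
            .eval()
            .cuda()
        )
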
@@ -112,23 +111,7 @@ def compile_torchtrt(model, input_ids, args):
     else:
         enabled_precisions = {torch.float32}
 
-    qformat = "_q_" + args.qformat if args.qformat else ""
-
-    logging_dir = f"./{args.model}_{args.precision}{qformat}"
-    # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
-    with (
-        torch_tensorrt.dynamo.Debugger(
-            "debug",
-            logging_dir=logging_dir,
-            # capture_fx_graph_after=["constant_fold"],
-            # save_engine_profile=True,
-            # profile_format="trex",
-            engine_builder_monitor=False,
-            # save_layer_info=True,
-        )
-        if args.debug
-        else nullcontext()
-    ):
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[input_ids, position_ids],
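
The replacement line relies on a common idiom: a conditional expression that evaluates to either a real context manager or contextlib.nullcontext(), so the with statement itself stays unconditional. A library-free sketch of that idiom (debug_logging is a made-up stand-in for torch_tensorrt.logging.debug):

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def debug_logging():
        # Stand-in for torch_tensorrt.logging.debug(), which enables
        # debug-level logging for the duration of the block.
        print("debug on")
        try:
            yield
        finally:
            print("debug off")

    debug = True  # stands in for args.debug

    with debug_logging() if debug else nullcontext():
        pass  # compilation would happen here
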
@@ -151,14 +134,12 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
151134 """
152135 Print the generated tokens from the model.
153136 """
154- out = tokenizer .decode (gen_tokens [0 ], skip_special_tokens = True )
155137 print (f"========= { backend_name } =========" )
156138 print (
157139 f"{ backend_name } model generated text: " ,
158- out ,
140+ tokenizer . decode ( gen_tokens [ 0 ], skip_special_tokens = True ) ,
159141 )
160142 print ("===================================" )
161- return out
162143
163144
164145def measure_perf (trt_model , input_signature , backend_name ):
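
With the temporary out variable removed, print_outputs decodes at the call site and no longer returns the text, so callers that captured its return value change as well (see the later hunks). A short sketch of what tokenizer.decode does here, using gpt2 as a hypothetical checkpoint:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical checkpoint

    ids = tokenizer("Hello world", return_tensors="pt").input_ids
    # gen_tokens[0] in the script is likewise a 1-D tensor of token IDs;
    # skip_special_tokens drops pad/eos markers from the rendered text.
    print(tokenizer.decode(ids[0], skip_special_tokens=True))
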
@@ -260,13 +241,13 @@ def measure_perf(trt_model, input_signature, backend_name):
     )
     arg_parser.add_argument(
         "--qformat",
-        help=("Apply quantization format. Options: fp8 (default: None)"),
+        help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
         default=None,
     )
     arg_parser.add_argument(
         "--pre_quantized",
         action="store_true",
-        help="Use pre-quantized model weights (default: False)",
+        help="Use pre-quantized hf model weights (default: False)",
     )
     args = arg_parser.parse_args()
 
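
For reference, the two flags behave differently in argparse: --qformat stores a string and defaults to None, while --pre_quantized is a store_true boolean that stays False unless the flag is passed. A self-contained sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--qformat", default=None)  # string value or None
    parser.add_argument("--pre_quantized", action="store_true")  # bool flag

    args = parser.parse_args(["--qformat", "fp8", "--pre_quantized"])
    assert args.qformat == "fp8" and args.pre_quantized is True
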
@@ -300,6 +281,7 @@ def measure_perf(trt_model, input_signature, backend_name):
     pyt_gen_tokens = None
     pyt_timings = None
     pyt_stats = None
+
     if args.qformat != None:
         model = quantize_model(model, args, tokenizer)
     if args.enable_pytorch_run:
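
Side note on the surrounding context: args.qformat != None works, but an identity test with is not None is the idiomatic None check in Python. Sketched below with a hypothetical pass-through in place of the script's quantize_model:

    def quantize_model(model, qformat):
        # Hypothetical stand-in for the script's quantization helper.
        print(f"applying {qformat} quantization")
        return model

    model = object()  # placeholder model
    qformat = "fp8"   # stands in for args.qformat

    if qformat is not None:  # idiomatic form of `qformat != None`
        model = quantize_model(model, qformat)
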
@@ -380,43 +362,19 @@ def measure_perf(trt_model, input_signature, backend_name):
             batch_size=args.batch_size,
             compile_time_s=None,
         )
-    match_result = "N/A"
-    torch_out = "N/A"
-    model_name = args.model.replace("/", "_")
-    qformat = args.qformat if args.qformat else "no_quant"
 
     if not args.benchmark:
         if args.enable_pytorch_run:
-            torch_out = print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+            print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
 
-        trt_out = print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+        print_outputs("TensorRT", trt_gen_tokens, tokenizer)
 
         if args.enable_pytorch_run:
             print(
                 f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
             )
-            match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
-        out_json_file = f"{model_name}_{qformat}_match.json"
-        result = {}
-        args_dict = vars(args)
-        result["args"] = args_dict
-        result["match"] = match_result
-        result["torch_out"] = torch_out
-        result["trt_out"] = trt_out
-        with open(os.path.join("result", out_json_file), "w") as f:
-            json.dump(result, f, indent=4)
-        print(f"Results saved to {out_json_file}")
+
     if args.benchmark:
-        result = {}
-        args_dict = vars(args)
-
-        result["args"] = args_dict
-        result["pyt_stats"] = pyt_stats if args.enable_pytorch_run else None
-        result["trt_stats"] = trt_stats if args.benchmark else None
-        out_json_file = f"{model_name}_{qformat}_benchmark.json"
-        with open(os.path.join("result", out_json_file), "w") as f:
-            json.dump(result, f, indent=4)
-        print(f"Results saved to {out_json_file}")
         if args.enable_pytorch_run:
             print("=========PyTorch PERFORMANCE============\n")
             print(pyt_stats)
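
The match check kept in the non-benchmark path rests on torch.equal, which returns True only when both tensors have the same shape and identical elements, so a single diverging generated token reports a mismatch. A tiny sketch:

    import torch

    a = torch.tensor([[1, 2, 3]])
    b = torch.tensor([[1, 2, 3]])
    c = torch.tensor([[1, 2, 4]])

    print(torch.equal(a, b))  # True: same shape, same elements
    print(torch.equal(a, c))  # False: one element differs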