@@ -9,6 +9,7 @@
 
 import argparse
 import copy
+import json
 import os
 import timeit
 from contextlib import nullcontext
@@ -21,9 +22,11 @@
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (
+    convert_linear_to_tensorrt_quantized,
     export_llm,
     generate,
     generate_with_static_cache,
+    quantize_model,
     record_stats,
     time_generate,
 )
@@ -48,6 +51,7 @@ def get_model(args):
         torch.nn.Module: The loaded and configured model ready for inference,
         moved to CUDA device with the specified precision
     """
+
     with torch.no_grad():
         model = (
             AutoModelForCausalLM.from_pretrained(
@@ -58,6 +62,8 @@ def get_model(args):
             .eval()
             .cuda()
         )
+        if args.pre_quantized:
+            model = convert_linear_to_tensorrt_quantized(model, args.model)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
@@ -106,7 +112,23 @@ def compile_torchtrt(model, input_ids, args):
     else:
         enabled_precisions = {torch.float32}
 
-    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    qformat = "_q_" + args.qformat if args.qformat else ""
+
+    logging_dir = f"./{args.model}_{args.precision}{qformat}"
+    # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+    with (
+        torch_tensorrt.dynamo.Debugger(
+            "debug",
+            logging_dir=logging_dir,
+            # capture_fx_graph_after=["constant_fold"],
+            # save_engine_profile=True,
+            # profile_format="trex",
+            engine_builder_monitor=False,
+            # save_layer_info=True,
+        )
+        if args.debug
+        else nullcontext()
+    ):
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[input_ids, position_ids],
@@ -129,12 +151,14 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
     """
-    Print the generated tokens from the model.
+    Print the generated tokens from the model and return the decoded text.
     """
+    out = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
     print(f"========= {backend_name} =========")
     print(
         f"{backend_name} model generated text: ",
-        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+        out,
     )
     print("===================================")
+    return out
 
 
 def measure_perf(trt_model, input_signature, backend_name):
@@ -234,13 +258,24 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
-
+    arg_parser.add_argument(
+        "--qformat",
+        help="Apply quantization format. Options: fp8 (default: None)",
+        default=None,
+    )
+    arg_parser.add_argument(
+        "--pre_quantized",
+        action="store_true",
+        help="Use pre-quantized model weights (default: False)",
+    )
     args = arg_parser.parse_args()
     with torch.inference_mode():
         model = get_model(args)
 
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
-
+        # Set pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
         # Prepare input for benchmarking or evaluation
         if args.benchmark:
             input_ids = torch.randint(
@@ -257,7 +292,8 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_gen_tokens = None
         pyt_timings = None
         pyt_stats = None
-
+        if args.qformat is not None:
+            model = quantize_model(model, args, tokenizer)
         if args.enable_pytorch_run:
             pyt_gen_tokens = generate(
                 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
@@ -336,19 +372,43 @@ def measure_perf(trt_model, input_signature, backend_name):
             batch_size=args.batch_size,
             compile_time_s=None,
         )
+        match_result = "N/A"
+        torch_out = "N/A"
+        model_name = args.model.replace("/", "_")
+        qformat = args.qformat if args.qformat else "no_quant"
 
         if not args.benchmark:
             if args.enable_pytorch_run:
-                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+                torch_out = print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
 
-            print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+            trt_out = print_outputs("TensorRT", trt_gen_tokens, tokenizer)
 
             if args.enable_pytorch_run:
                 print(
                     f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
                 )
-
+                match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
+            out_json_file = f"{model_name}_{qformat}_match.json"
+            result = {}
+            result["match"] = match_result
+            result["torch_out"] = torch_out
+            result["trt_out"] = trt_out
+            os.makedirs("result", exist_ok=True)  # create the output dir if missing
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
         if args.benchmark:
+            result = {}
+            args_dict = vars(args)
+
+            result["args"] = args_dict
+            result["pyt_stats"] = pyt_stats if args.enable_pytorch_run else None
+            result["trt_stats"] = trt_stats
+            out_json_file = f"{model_name}_{qformat}_benchmark.json"
+            os.makedirs("result", exist_ok=True)  # create the output dir if missing
+            with open(os.path.join("result", out_json_file), "w") as f:
+                json.dump(result, f, indent=4)
+            print(f"Results saved to {out_json_file}")
             if args.enable_pytorch_run:
                 print("=========PyTorch PERFORMANCE============\n")
                 print(pyt_stats)
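
Note on the new quantization path: `quantize_model` and `convert_linear_to_tensorrt_quantized` are imported from `utils`, but their implementations are not part of this diff. As a rough orientation, the `--qformat fp8` path typically amounts to a ModelOpt calibration pass like the sketch below; the helper body, the calibration prompt, and the single-sample loop are illustrative assumptions, not the PR's actual code.

```python
# Hypothetical sketch of a utils.quantize_model for --qformat fp8.
# Assumes NVIDIA TensorRT Model Optimizer (modelopt) is installed; the
# real helper in utils.py is not shown in this diff and may differ.
import torch
import modelopt.torch.quantization as mtq


def quantize_model(model, args, tokenizer):
    if args.qformat != "fp8":
        raise ValueError(f"Unsupported qformat: {args.qformat}")

    def forward_loop(m):
        # Run representative text through the model so ModelOpt can record
        # activation ranges; a real calibration set would be much larger.
        inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
        with torch.no_grad():
            m(inputs.input_ids)

    # mtq.quantize inserts quantizers into the model in place and runs
    # forward_loop once to calibrate their scales.
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```

`--pre_quantized` covers the complementary case: the checkpoint already ships quantized weights, so `convert_linear_to_tensorrt_quantized` presumably only swaps `Linear` modules for TensorRT-compatible quantized equivalents rather than re-calibrating.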