Skip to content

Commit b43c4c2

Browse files
authored
feat: add --use_python_runtime and --enable_cuda_graph args to the perf run script (#3397)
1 parent b0464ca commit b43c4c2

File tree

1 file changed

+96

-25

lines changed

tools/perf/perf_run.py

+96-25
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,17 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
175175
"inputs": input_tensors,
176176
"enabled_precisions": {precision_to_dtype(precision)},
177177
"truncate_long_and_double": params.get("truncate", False),
178+
"use_python_runtime": params.get("use_python_runtime", False),
178179
}
179180

180181
if precision == "int8":
181182
compile_settings.update({"calib": params.get("calibration_cache")})
182183

184+
if params.get("enable_cuda_graph", False):
185+
logging.warning(
186+
f"Torchscript backend doesn't support CUDA Graphs. `--enable_cuda_graph` will be ignored."
187+
)
188+
183189
start_compile = timeit.default_timer()
184190
model = torchtrt.compile(model, ir="ts", **compile_settings)
185191
end_compile = timeit.default_timer()
@@ -217,19 +223,34 @@ def run_hf_dynamo(model, input_tensors, params, precision, batch_size):
217223
inputs=input_tensors,
218224
enabled_precisions={precision_to_dtype(precision)},
219225
truncate_double=params.get("truncate", False),
226+
use_python_runtime=params.get("use_python_runtime", False),
220227
)
221228
end_compile = timeit.default_timer()
222229
compile_time_s = end_compile - start_compile
223-
record_llm_perf(
224-
trt_model,
225-
"Dynamo",
226-
input_tensors,
227-
precision,
228-
osl,
229-
batch_size,
230-
iters,
231-
compile_time_s,
232-
)
230+
231+
if params.get("enable_cuda_graph", False):
232+
with torchtrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module:
233+
record_llm_perf(
234+
cudagraphs_module,
235+
"Dynamo",
236+
input_tensors,
237+
precision,
238+
osl,
239+
batch_size,
240+
iters,
241+
compile_time_s,
242+
)
243+
else:
244+
record_llm_perf(
245+
trt_model,
246+
"Dynamo",
247+
input_tensors,
248+
precision,
249+
osl,
250+
batch_size,
251+
iters,
252+
compile_time_s,
253+
)
233254

234255

235256
@run_with_try_except
@@ -262,14 +283,27 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
262283
),
263284
cache_built_engines=params.get("cache_built_engines", False),
264285
reuse_cached_engines=params.get("reuse_cached_engines", False),
286+
use_python_runtime=params.get("use_python_runtime", False),
265287
)
266288
end_compile = timeit.default_timer()
267289
compile_time_s = end_compile - start_compile
268290
iters = params.get("iterations", 20)
269291

270-
record_perf(
271-
model, "Dynamo", input_tensors, precision, iters, batch_size, compile_time_s
272-
)
292+
if params.get("enable_cuda_graph", False):
293+
with torchtrt.runtime.enable_cudagraphs(model) as cudagraphs_module:
294+
record_perf(
295+
cudagraphs_module,
296+
"Dynamo",
297+
input_tensors,
298+
precision,
299+
iters,
300+
batch_size,
301+
compile_time_s,
302+
)
303+
else:
304+
record_perf(
305+
model, "Dynamo", input_tensors, precision, iters, batch_size, compile_time_s
306+
)
273307

274308

275309
@run_with_try_except
@@ -292,6 +326,7 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size):
292326
"enabled_precisions": {precision_to_dtype(precision)},
293327
"truncate": params.get("truncate", False),
294328
"min_block_size": params.get("min_block_size", 1),
329+
"use_python_runtime": params.get("use_python_runtime", False),
295330
}
296331
start_compile = timeit.default_timer()
297332
model = torch.compile(model, backend="tensorrt", dynamic=None, options=compile_spec)
@@ -300,15 +335,27 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size):
300335
compile_time_s = end_compile - start_compile
301336
iters = params.get("iterations", 20)
302337

303-
record_perf(
304-
model,
305-
"torch_compile",
306-
input_tensors,
307-
precision,
308-
iters,
309-
batch_size,
310-
compile_time_s,
311-
)
338+
if params.get("enable_cuda_graph", False):
339+
with torchtrt.runtime.enable_cudagraphs(model) as cudagraphs_module:
340+
record_perf(
341+
cudagraphs_module,
342+
"torch_compile",
343+
input_tensors,
344+
precision,
345+
iters,
346+
batch_size,
347+
compile_time_s,
348+
)
349+
else:
350+
record_perf(
351+
model,
352+
"torch_compile",
353+
input_tensors,
354+
precision,
355+
iters,
356+
batch_size,
357+
compile_time_s,
358+
)
312359

313360

314361
@run_with_try_except
@@ -320,9 +367,13 @@ def run_hf_inductor(model, input_tensors, params, precision, batch_size):
320367
# Mark dynamic shapes for input sequence
321368
input_seq = input_tensors[0]
322369
torch._dynamo.mark_dynamic(input_seq, 1, min=1, max=osl)
370+
mode = "max-autotune"
371+
if params.get("enable_cuda_graph", False):
372+
mode = "reduce-overhead"
373+
323374
start_compile = timeit.default_timer()
324375
# Compile the model
325-
model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune")
376+
model = torch.compile(model, backend="inductor", dynamic=None, mode=mode)
326377
model(input_seq)
327378
end_compile = timeit.default_timer()
328379
compile_time_s = end_compile - start_compile
@@ -356,15 +407,25 @@ def run_inductor(model, input_tensors, params, precision, batch_size):
356407
if params["is_text_llm"]:
357408
return run_hf_inductor(model, input_tensors, params, precision, batch_size)
358409

410+
mode = "max-autotune"
411+
if params.get("enable_cuda_graph", False):
412+
mode = "reduce-overhead"
413+
359414
start_compile = timeit.default_timer()
360-
model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune")
415+
model = torch.compile(model, backend="inductor", dynamic=None, mode=mode)
361416
model(*input_tensors)
362417
end_compile = timeit.default_timer()
363418
compile_time_s = end_compile - start_compile
364419
iters = params.get("iterations", 20)
365420

366421
record_perf(
367-
model, "inductor", input_tensors, precision, iters, batch_size, compile_time_s
422+
model,
423+
"inductor",
424+
input_tensors,
425+
precision,
426+
iters,
427+
batch_size,
428+
compile_time_s,
368429
)
369430

370431

@@ -587,6 +648,16 @@ def run(
587648
action="store_true",
588649
help="Boolean flag to determine if the user provided model is a TRT engine or not",
589650
)
651+
arg_parser.add_argument(
652+
"--use_python_runtime",
653+
action="store_true",
654+
help="Whether to use Python runtime or not. Using C++ runtime by default",
655+
)
656+
arg_parser.add_argument(
657+
"--enable_cuda_graph",
658+
action="store_true",
659+
help="Whether to enable CUDA Graph. It is not used by default",
660+
)
590661
arg_parser.add_argument(
591662
"--report",
592663
type=str,

0 commit comments

Comments
 (0)