Skip to content

Commit

Permalink
chore: Add TRT runner via onnx (#2503)
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 authored Nov 30, 2023
1 parent 9b88e92 commit 9ed1849
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
4 changes: 2 additions & 2 deletions tools/perf/benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ python hub.py

batch_sizes=(1 2 4 8 16 32 64 128 256)
large_model_batch_sizes=(1 2 4 8 16 32 64)
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor")
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor")
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "tensorrt")
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "tensorrt")


# Benchmark VGG16 model
Expand Down
52 changes: 24 additions & 28 deletions tools/perf/perf_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# Importing supported Backends
import torch
import torch_tensorrt as torchtrt
from utils import (
BENCHMARK_MODELS,
parse_backends,
Expand All @@ -23,8 +24,6 @@
precision_to_dtype,
)

import torch_tensorrt as torchtrt

WARMUP_ITER = 10
results = []

Expand Down Expand Up @@ -294,29 +293,30 @@ def run_tensorrt(
input_tensors,
params,
precision,
is_trt_engine=False,
batch_size=1,
):
engine = None

# If the model file is a TensorRT engine then directly deserialize and run inference
# else convert the torch module to a TensorRT engine first and then run inference
if not is_trt_engine:
compile_settings = {
"inputs": input_tensors,
"enabled_precisions": {precision_to_dtype(precision)},
"truncate_long_and_double": params.get("truncate", False),
}

print("Converting method to TensorRT engine...")
with torch.no_grad(), torchtrt.logging.errors():
model = torchtrt.ts.convert_method_to_trt_engine(
model, "forward", **compile_settings
)

# Export an ONNX model and convert to TRT
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse_from_file("./tmp.onnx")
if not success:
raise ValueError("ONNX conversion failed")

config = builder.create_builder_config()
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
start_compile = time.time_ns()
serialized_engine = builder.build_serialized_network(network, config)
end_compile = time.time_ns()
compile_time_s = (end_compile - start_compile) / 1e9
# Deserialize the TensorRT engine
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(model)
with trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(serialized_engine)

print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
iters = params.get("iterations", 20)
Expand Down Expand Up @@ -351,7 +351,7 @@ def run_tensorrt(
meas_time = end_time - start_time
timings.append(meas_time)

recordStats("TensorRT", timings, precision, batch_size)
recordStats("TensorRT", timings, precision, batch_size, compile_time_s)


# Deploys inference run for different backend configurations
Expand Down Expand Up @@ -427,11 +427,10 @@ def run(
)
elif backend == "tensorrt":
run_tensorrt(
model,
model_torch,
input_tensors,
params,
precision,
is_trt_engine,
batch_size,
)
elif backend == "dynamo":
Expand All @@ -440,9 +439,6 @@ def run(
elif backend == "torch_compile":
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)

elif backend == "torch_compile":
run_torch_compile(model_torch, input_tensors, params, precision, batch_size)

elif backend == "inductor":
run_inductor(model_torch, input_tensors, params, precision, batch_size)

Expand Down
2 changes: 2 additions & 0 deletions tools/perf/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
numpy
argparse
pyyaml
onnx
transformers==4.33.2
diffusers==0.21.4
pandas==2.0.1
timm==0.9.8

0 comments on commit 9ed1849

Please sign in to comment.