
[Suggestion] Why does the tensorrt backend use trtexec instead of the TensorRT Python interface? #367

Open
@LeiWang1999

Description


From my point of view, with the Python interface we can insert cudaProfilerStart() and cudaProfilerStop() to profile the program more precisely. If we use trtexec, superbench starts another thread to execute it, so nvprof cannot correctly profile the real command. Moreover, profiling trtexec directly captures both the engine-build phase and the runtime phase, while in most cases we only need the latter. A minimal sketch of the bracketing idea is below.
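A minimal sketch of that bracketing, assuming pycuda (whose start_profiler()/stop_profiler() map to cudaProfilerStart()/cudaProfilerStop()); run_inference_iterations is a hypothetical stand-in for the timed loop in the full example further down:

import pycuda.autoinit  # noqa: F401 -- creates the CUDA context
import pycuda.driver as cuda

def run_inference_iterations():
    pass  # hypothetical placeholder for the timed loop in the full example

cuda.start_profiler()   # cudaProfilerStart(): nvprof begins capturing here
run_inference_iterations()
cuda.stop_profiler()    # cudaProfilerStop(): build and warmup stay out of the profile

Launched as nvprof --profile-from-start off python benchmark.py (benchmark.py being a placeholder name), nvprof records only the region between the two calls.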

A full TensorRT Python interface example:

import tensorrt as trt
import torch

import pycuda.autoinit  # noqa: F401 -- initializes CUDA and creates the device context

import common  # helpers from the TensorRT samples (allocate_buffers, do_inference_v2, GiB, EXPLICIT_BATCH)

TRT_LOGGER = trt.Logger()


def inference(context, test_data):
    # allocate_buffers / do_inference_v2 come from the TensorRT samples' common.py;
    # do_inference_v2 is assumed here to be patched to also return the elapsed time.
    inputs, outputs, bindings, stream = common.allocate_buffers(context.engine)
    inputs[0].host = test_data

    result, elapsed_time = common.do_inference_v2(
        context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

    return result, elapsed_time

# Builds an engine from an ONNX model.
def build_engine(model_file):
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(common.EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser, \
            builder.create_builder_config() as trt_config:

        # Note: max_batch_size must stay at 1 because of how
        # common.allocate_buffers sizes the host/device buffers.
        builder.max_batch_size = 1
        trt_config.max_workspace_size = common.GiB(4)

        # Parse the ONNX model.
        with open(model_file, 'rb') as model:
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None


        # Optional per-layer int8 override (kept for reference); this approach
        # may be incorrect when the network has more than one output.
        """
        for i in range(network.num_layers):
            layer = network.get_layer(i)
            layer.precision = trt.int8
            layer.set_output_type(0, trt.int8)
        """

        # Build the engine.
        engine = builder.build_engine(network, trt_config)
        return engine

onnx_path = "/workspace/v-leiwang3/benchmark/nnfusion_models/resnet50.float32.1.onnx"
dummy_input = torch.rand(1, 3, 224, 224).numpy()

engine = build_engine(onnx_path)
context = engine.create_execution_context()

# warmup
for _ in range(5):
    _, elapsed = inference(context, dummy_input)

# timed iterations
times = []
for _ in range(100):
    _, elapsed = inference(context, dummy_input)
    times.append(elapsed)

print(f'average time: {sum(times) / len(times) * 1000:.3f} ms')
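To keep the engine-build phase out of the profiled run entirely, one could also serialize the engine once and reload it in the run that nvprof sees; a sketch using the build_engine() and TRT_LOGGER defined above (the resnet50.plan path is a placeholder):

# one-off build, done outside the profiled run
engine = build_engine(onnx_path)
with open('resnet50.plan', 'wb') as f:
    f.write(engine.serialize())

# profiled run: deserialize instead of rebuilding
runtime = trt.Runtime(TRT_LOGGER)
with open('resnet50.plan', 'rb') as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()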
