Description
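Hello, I get an error when converting an ONNX model that contains MMCVModulatedDeformConv2d into a TensorRT engine. Could the MM maintainers please take a look? Many thanks!
The code is as follows: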
```python
import torch
import onnx
from mmcv.tensorrt import (TRTWrapper, onnx2trt, save_trt_engine,
                           is_tensorrt_plugin_loaded)

# assert is_tensorrt_plugin_loaded(), 'Requires to compile TensorRT plugins in mmcv'
onnx_file = '/xx_backbone.onnx'
trt_file = 'sample.trt'
onnx_model = onnx.load(onnx_file)

# Model input
inputs = torch.rand(1, 3, 288, 288).cuda()
# Model input shape info: [min_shape, opt_shape, max_shape] per input name
opt_shape_dict = {
    'input': [list(inputs.shape),
              list(inputs.shape),
              list(inputs.shape)]
}

# Create TensorRT engine
max_workspace_size = 1 << 30
trt_engine = onnx2trt(
    onnx_model,
    opt_shape_dict,
    max_workspace_size=max_workspace_size)
# Save TensorRT engine
save_trt_engine(trt_engine, trt_file)

# Run inference with TensorRT
trt_model = TRTWrapper(trt_file, ['input'], ['output'])
with torch.no_grad():
    trt_outputs = trt_model({'input': inputs})
    output = trt_outputs['output']
```
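In the snippet above, each `opt_shape_dict` entry maps an input name to `[min_shape, opt_shape, max_shape]`, and passing the same shape three times gives a static 1x3x288x288 input. The sketch below is a hypothetical helper (`make_opt_shape_dict` is not part of mmcv) that builds such an entry and checks that the three shapes have the same rank and satisfy min <= opt <= max per dimension. It is only a convenience under that assumed `[min, opt, max]` convention, not a fix for the reported error.

```python
from typing import Dict, List, Optional, Sequence

Shape = List[int]


def make_opt_shape_dict(name: str,
                        min_shape: Sequence[int],
                        opt_shape: Optional[Sequence[int]] = None,
                        max_shape: Optional[Sequence[int]] = None
                        ) -> Dict[str, List[Shape]]:
    """Hypothetical helper: build {name: [min, opt, max]}.

    opt_shape defaults to min_shape and max_shape defaults to opt_shape,
    so passing only min_shape describes a static input shape.
    """
    min_s = list(min_shape)
    opt_s = list(opt_shape) if opt_shape is not None else list(min_s)
    max_s = list(max_shape) if max_shape is not None else list(opt_s)
    if not (len(min_s) == len(opt_s) == len(max_s)):
        raise ValueError("min/opt/max shapes must have the same rank")
    for lo, mid, hi in zip(min_s, opt_s, max_s):
        if not (lo <= mid <= hi):
            raise ValueError(
                f"expected min <= opt <= max per dim, got {lo}, {mid}, {hi}")
    return {name: [min_s, opt_s, max_s]}


if __name__ == "__main__":
    # Static shape, equivalent to the opt_shape_dict in the issue.
    print(make_opt_shape_dict("input", [1, 3, 288, 288]))
```

For a dynamic batch, one might call `make_opt_shape_dict('input', [1, 3, 288, 288], [1, 3, 288, 288], [4, 3, 288, 288])`.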
The TensorRT version matches the one required by the mmcv documentation: TensorRT-7.2.1.6.
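Because `MMCVModulatedDeformConv2d` is not a standard ONNX operator, TensorRT can only build it if mmcv's compiled TensorRT plugins are available. That is what the commented-out `is_tensorrt_plugin_loaded()` assertion in the script checks, so re-enabling it may be a useful first step. The sketch below lists which nodes in the graph fall outside the default ONNX domains. It assumes the mmcv exporter places its custom ops in a non-default domain (for example `mmcv`). The function name `custom_op_types` is hypothetical. It reads only the `op_type` and `domain` fields of each node, so the demo uses `SimpleNamespace` stand-ins instead of a real ONNX file.

```python
from types import SimpleNamespace
from typing import Iterable, List

# Domains that denote standard ONNX operators.
DEFAULT_ONNX_DOMAINS = {"", "ai.onnx"}


def custom_op_types(nodes: Iterable) -> List[str]:
    """Return sorted 'domain::op_type' strings for non-standard nodes.

    Each node only needs `op_type` and `domain` attributes, which
    onnx NodeProto objects (model.graph.node) provide.
    """
    found = {f"{n.domain}::{n.op_type}"
             for n in nodes if n.domain not in DEFAULT_ONNX_DOMAINS}
    return sorted(found)


if __name__ == "__main__":
    # Stand-in nodes; with a real model use onnx.load(path).graph.node.
    demo_nodes = [
        SimpleNamespace(op_type="Conv", domain=""),
        SimpleNamespace(op_type="MMCVModulatedDeformConv2d", domain="mmcv"),
        SimpleNamespace(op_type="Relu", domain="ai.onnx"),
    ]
    print(custom_op_types(demo_nodes))
```

With the real model, `custom_op_types(onnx.load(onnx_file).graph.node)` would show every op that needs a TensorRT plugin.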