
Commit c2d90c4

add trt python demo
1 parent 4dcf170 commit c2d90c4

2 files changed, +212 -19 lines changed


tensorrt/README.md (+55 -19)

## Deploy with Tensorrt

First, we should export our trained model to an onnx model:
```
$ cd BiSeNet/
$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx
```

**NOTE:** I use a cropsize of `1024x2048` in this example; you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here.
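
For example, here is a minimal sketch to double-check which input shape got baked into the exported model (it assumes the `onnx` python package is installed, which the rest of this repo does not require):
```
import onnx

## print the input shape of the exported model; the two spatial dims
## are the inference cropsize that is now fixed
model = onnx.load('./model.onnx')
dims = model.graph.input[0].type.tensor_type.shape.dim
print([d.dim_value if d.dim_value > 0 else d.dim_param for d in dims])
```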

Then we can use either C++ or python to compile the model and run inference.


### Using C++

#### My platform

* ubuntu 18.04
* nvidia Tesla T4 gpu, driver newer than 450.80
* cuda 11.3, cudnn 8
* cmake 3.17.1
* opencv built from source
* tensorrt 8.2.5.1


#### Build with source code
Just use the standard cmake build method:
```
mkdir -p tensorrt/build
cd tensorrt/build
cmake ..
make
```
This would generate a `./segment` in the `tensorrt/build` directory.


#### Convert onnx to tensorrt model
If you can successfully compile the source code, you can parse the onnx model to a tensorrt model like this:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt
```
If your gpu supports acceleration with fp16 inference, you can add a `--fp16` option like this:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16
```
Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the args in the above commands.


#### Infer with one single image
Run inference like this:
```
$ ./segment run /path/to/saved_model.trt /path/to/input/image.jpg /path/to/saved_img.jpg
```


#### Test speed
The speed depends on the specific gpu platform you are working on; you can test the fps on your gpu like this:
```
$ ./segment test /path/to/saved_model.trt
```


#### Tips:
1. ~Since tensorrt 7.0.0 cannot parse well the `bilinear interpolation` op exported from pytorch, I replace it with pytorch `nn.PixelShuffle`, which would bring some performance overhead (more flops and parameters) and make inference a bit slower. Also, due to the `nn.PixelShuffle` op, you **must** export the onnx model with an input size that is a multiple of 32.~
If you are using 7.2.3.4 or newer versions, you should not have problems with `interpolate` anymore.

2. ~There would be some problem for tensorrt 7.0.0 to parse the `nn.AvgPool2d` op from pytorch with onnx opset11. So I use opset10 to export the model.~
Likewise, you do not need to worry about this anymore with 7.2.3.4 or newer versions.

3. The speed (fps) is tested on a single nvidia Tesla T4 gpu with `batchsize=1` and `cropsize=(1024,2048)`. Please note that the T4 gpu is almost 2 times slower than a 2080ti, so you should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed.

4. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 29Mb (fp16) and 128Mb (fp32), and the size of bisenetv2 is 16Mb (fp16) and 42Mb (fp32). However, the fps of bisenetv1 is 68 (fp16) and 23 (fp32), while the fps of bisenetv2 is 59 (fp16) and 21 (fp32). It is obvious that bisenetv2 has fewer parameters than bisenetv1, but the speed is the other way around. I am not sure whether this is because tensorrt has a worse optimization strategy for some ops used in bisenetv2 (such as depthwise convolution), or because of the limitations of the gpu on different ops. Please tell me if you have a better idea on this.


### Using python

You can also use the python script to compile the model and run inference.


#### Compile onnx model to tensorrt

With this command:
```
$ cd BiSeNet/tensorrt
$ python segment.py compile --onnx /path/to/model.onnx --savepth ./model.trt --quant fp16/fp32
```

This will compile the onnx model into a tensorrt serialized engine and save it to `./model.trt`.
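
Before running inference, a quick sanity check (a minimal sketch, assuming the engine was saved to `./model.trt` as above and a tensorrt 8.x python install) is to deserialize the engine and print its binding shapes:
```
import tensorrt as trt

## the input binding's spatial dims should match the cropsize
## chosen when exporting to onnx
TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(None, "")
with open('./model.trt', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
for i in range(engine.num_bindings):
    print(engine.get_binding_name(i), engine.get_binding_shape(i))
```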


#### Inference with Tensorrt

Run inference like this:
```
$ python segment.py run --mdpth ./model.trt --impth ../example.png --outpth ./res.png
```

This will use the tensorrt model compiled above, and run inference with the example image.
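
The python demo has no counterpart of the C++ `./segment test` mode. If you want a rough fps number from python, a timing sketch along the following lines should work (assumptions: the engine is at `./model.trt`, the engine has static input shapes as exported above, and the same tensorrt/pycuda setup as `segment.py` below; this is not part of the repo):
```
import time
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(None, "")
with open('./model.trt', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

## allocate one device buffer per binding; feed random data to the input
bindings = []
for i in range(engine.num_bindings):
    shape = tuple(engine.get_binding_shape(i))
    dtype = trt.nptype(engine.get_binding_dtype(i))
    host = np.random.rand(*shape).astype(dtype)
    dev = cuda.mem_alloc(host.nbytes)
    cuda.memcpy_htod(dev, np.ascontiguousarray(host))
    bindings.append(int(dev))

## warm up, then time synchronous executions
for _ in range(10):
    context.execute_v2(bindings)
n_runs = 100
t0 = time.time()
for _ in range(n_runs):
    context.execute_v2(bindings)
print('fps:', n_runs / (time.time() - t0))
```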

tensorrt/segment.py (+157, new file)

```
import os
import os.path as osp
import cv2
import numpy as np
import logging
import argparse

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit


parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
compile_parser = subparsers.add_parser('compile')
compile_parser.add_argument('--onnx')
compile_parser.add_argument('--quant', default='fp32')
compile_parser.add_argument('--savepth', default='./model.trt')
run_parser = subparsers.add_parser('run')
run_parser.add_argument('--mdpth')
run_parser.add_argument('--impth')
run_parser.add_argument('--outpth', default='./res.png')
args = parser.parse_args()


np.random.seed(123)
in_datatype = trt.nptype(trt.float32)
out_datatype = trt.nptype(trt.int32)
palette = np.random.randint(0, 256, (256, 3)).astype(np.uint8)

ctx = pycuda.autoinit.context
trt.init_libnvinfer_plugins(None, "")
TRT_LOGGER = trt.Logger()


## read the image, resize it to the network input size and normalize it
def get_image(impth, size):
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
    var = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
    iH, iW = size[0], size[1]
    img = cv2.imread(impth)[:, :, ::-1]
    orgH, orgW, _ = img.shape
    img = cv2.resize(img, (iW, iH)).astype(np.float32)
    img = img.transpose(2, 0, 1) / 255.
    img = (img - mean) / var
    return img, (orgH, orgW)


## allocate pagelocked host buffers and device buffers for the input
## and output bindings
def allocate_buffers(engine):
    h_input = cuda.pagelocked_empty(
        trt.volume(engine.get_binding_shape(0)), dtype=in_datatype)
    print(engine.get_binding_shape(0))
    d_input = cuda.mem_alloc(h_input.nbytes)
    h_outputs, d_outputs = [], []
    n_outs = 1
    for i in range(n_outs):
        h_output = cuda.pagelocked_empty(
            trt.volume(engine.get_binding_shape(i+1)),
            dtype=out_datatype)
        d_output = cuda.mem_alloc(h_output.nbytes)
        h_outputs.append(h_output)
        d_outputs.append(d_output)
    stream = cuda.Stream()
    return (
        stream,
        h_input,
        d_input,
        h_outputs,
        d_outputs,
    )


def build_engine_from_onnx(onnx_file_path):
    engine = None  ## add this to avoid returning a deleted engine
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:

        # Parse model file
        print(f'Loading ONNX file from path {onnx_file_path}...')
        assert os.path.exists(onnx_file_path), f'cannot find {onnx_file_path}'
        with open(onnx_file_path, 'rb') as fr:
            if not parser.parse(fr.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                assert False

        # build settings
        builder.max_batch_size = 128
        config.max_workspace_size = 1 << 30  # 1G
        if args.quant == 'fp16':
            config.set_flag(trt.BuilderFlag.FP16)

        print("Start to build Engine")
        plan = builder.build_serialized_network(network, config)
        engine = runtime.deserialize_cuda_engine(plan)
        return engine


def serialize_engine_to_file(engine, savepth):
    plan = engine.serialize()
    with open(savepth, "wb") as fw:
        fw.write(plan)


def deserialize_engine_from_file(savepth):
    with open(savepth, 'rb') as fr, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(fr.read())
    return engine


def main():
    if args.command == 'compile':
        engine = build_engine_from_onnx(args.onnx)
        serialize_engine_to_file(engine, args.savepth)

    elif args.command == 'run':
        engine = deserialize_engine_from_file(args.mdpth)

        ishape = engine.get_binding_shape(0)
        img, (orgH, orgW) = get_image(args.impth, ishape[2:])

        ## allocate buffers and create execution context
        (
            stream,
            h_input,
            d_input,
            h_outputs,
            d_outputs,
        ) = allocate_buffers(engine)
        ctx.push()
        context = engine.create_execution_context()
        ctx.pop()
        bds = [int(d_input), ] + [int(el) for el in d_outputs]

        ## copy the input to device, run the engine, copy the output back
        h_input = np.ascontiguousarray(img)
        cuda.memcpy_htod_async(d_input, h_input, stream)
        context.execute_async(
            bindings=bds, stream_handle=stream.handle)
        for h_output, d_output in zip(h_outputs, d_outputs):
            cuda.memcpy_dtoh_async(h_output, d_output, stream)
        stream.synchronize()

        ## colorize the predicted label map, resize back to the original size and save
        out = palette[h_outputs[0]]
        outshape = engine.get_binding_shape(1)
        H, W = outshape[1], outshape[2]
        out = out.reshape(H, W, 3)
        out = cv2.resize(out, (orgW, orgH))
        cv2.imwrite(args.outpth, out)


if __name__ == '__main__':
    main()
```

0 commit comments
