
Commit b03c77b

move preprocessing to server side as a backend, and add http client
1 parent befab8b commit b03c77b

18 files changed: +1840 -184 lines

tis/README.md

+36-8
@@ -23,13 +23,38 @@ $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /
$ cp -riv ./model.onnx tis/models/bisenetv2/1
```

-#### 2. start service
-We start serving with docker:
+#### 2. prepare the preprocessing backend
+We can use either the python backend or the cpp backend for preprocessing on the server side.
+Firstly, we pull the docker image and start a serving container:
```
-$ docker pull nvcr.io/nvidia/tritonserver:21.10-py3
-$ docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
+$ docker pull nvcr.io/nvidia/tritonserver:22.07-py3
+$ docker run -it --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models -v /path/to/BiSeNet/:/BiSeNet nvcr.io/nvidia/tritonserver:22.07-py3 bash
```
+From here on, we are in the container environment. Let's prepare the backends in the container:
```
+# ln -s /usr/local/bin/pip3.8 /usr/bin/pip3.8
+# /usr/bin/python3 -m pip install pillow
+# apt update && apt install rapidjson-dev libopencv-dev
```
+Then we download cmake 3.22 and unzip it in the container; we use this cmake 3.22 in the following operations (a possible download recipe is sketched below).
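A possible way to fetch cmake 3.22 inside the container (the release URL and install prefix below are the editor's assumptions for illustration, not part of this commit):
```
# wget https://github.com/Kitware/CMake/releases/download/v3.22.0/cmake-3.22.0-linux-x86_64.tar.gz
# tar -xzf cmake-3.22.0-linux-x86_64.tar.gz -C /opt
# export PATH=/opt/cmake-3.22.0-linux-x86_64/bin:$PATH
```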
+We compile the c++ backend:
```
+# cp -riv /BiSeNet/tis/self_backend /opt/tritonserver/backends
+# chmod 777 /opt/tritonserver/backends/self_backend
+# cd /opt/tritonserver/backends/self_backend
+# mkdir -p build && cd build
+# cmake .. && make -j4
+# mv -iuv libtriton_self_backend.so ..
```
+By now, we should have the backends prepared.
+#### 3. start service
+Following the above steps, we start the server inside the docker container:
```
+# tritonserver --model-repository=/models
```
In general, the service would start now. You can check whether the service has started by:
```
$ curl -v localhost:8000/v2/health/ready
@@ -38,10 +63,12 @@ $ curl -v localhost:8000/v2/health/ready
By default, we use gpu 0 and gpu 1; you can change this configuration in the `config.pbtxt` file.
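For reference, gpu placement lives in the `instance_group` section of `config.pbtxt`; a minimal sketch (the values here are illustrative, not copied from this repo's config):
```
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0, 1 ]
  }
]
```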

-### Client
+### Request with client

We can call the model service with both python and c++ methods.

+From here on, we are on the client machine, rather than in the server docker container.

#### 1. python method
@@ -50,10 +77,11 @@ Firstly, we need to install dependency package:
$ python -m pip install tritonclient[all]==2.15.0
```
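After installing, a quick way to sanity-check the connection from python (the url and model name are placeholders; this snippet is an editor's sketch, not part of the commit):
```
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url='127.0.0.1:8000')
print(client.is_server_ready())                 # True once the server is up
print(client.is_model_ready('bisenetv2', '1'))  # True once the model is loaded
```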

-Then we can run the script:
+Then we can run the script for both http request and grpc request:
```
$ cd BiSeNet/tis
-$ python client.py
+$ python client_http.py # if you want to use http client
+$ python client_grpc.py # if you want to use grpc client
```

This would generate a result file named `res.jpg` in the `BiSeNet/tis` directory.
@@ -92,4 +120,4 @@ Finally, we run the client and see a result file named `res.jpg` generated:

### In the end

-This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and model pipeline. If you have interest on this, you can learn more in the official document.
+This is a simple demo with only basic functions. There are many other useful features, such as shared memory and dynamic batching. If you are interested in these, you can learn more in the official documentation.
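As a concrete pointer, dynamic batching is switched on per model in `config.pbtxt`; a minimal sketch (values illustrative only, not from this repo — and note the preprocessing backend itself cannot batch, since raw image sizes differ, as `tis/client_backend.py` points out):
```
dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 100
}
```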

tis/client_backend.py

+81
@@ -0,0 +1,81 @@
import argparse
import sys
import numpy as np
import cv2
import gevent.ssl

import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException


np.random.seed(123)
palette = np.random.randint(0, 256, (100, 3))


url = '10.128.61.8:8000'
# url = '127.0.0.1:8000'
model_name = 'preprocess_cpp'
model_version = '1'
inp_name = 'raw_img_bytes'
outp_name = 'processed_img'
inp_dtype = 'UINT8'
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]


## prepare image and mean/std
inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
mean = np.array(mean, dtype=np.float32)[None, ...]
std = np.array(std, dtype=np.float32)[None, ...]
inputs = []
inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype))
inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32'))
inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32'))
inputs[0].set_data_from_numpy(inp_data, binary_data=True)
inputs[1].set_data_from_numpy(mean, binary_data=True)
inputs[2].set_data_from_numpy(std, binary_data=True)

## client
triton_client = httpclient.InferenceServerClient(
        url=url, verbose=False, concurrency=32)

## infer
# sync
# results = triton_client.infer(model_name, inputs)

# async
# results = triton_client.async_infer(
#         model_name,
#         inputs,
#         outputs=None,
#         query_params=None,
#         headers=None,
#         request_compression_algorithm=None,
#         response_compression_algorithm=None)
# results = results.get_result() # async infer only

## fire 10 async requests; server-side dynamic batching is not possible here,
## since different pictures have different raw sizes
results = []
for i in range(10):
    r = triton_client.async_infer(
            model_name,
            inputs,
            outputs=None,
            query_params=None,
            headers=None,
            request_compression_algorithm=None,
            response_compression_algorithm=None)
    results.append(r)
results = [r.get_result() for r in results] # block until every request finishes
results = results[-1] # inspect the last response

# get output
outp = results.as_numpy(outp_name).squeeze()
print(outp.shape)

tis/client.py → tis/client_grpc.py (renamed)

+30-30
@@ -13,21 +13,36 @@

-# url = '10.128.61.7:8001'
-url = '127.0.0.1:8001'
-model_name = 'bisenetv2'
+url = '10.128.61.8:8001'
+# url = '127.0.0.1:8001'
+model_name = 'bisenetv1'
model_version = '1'
-inp_name = 'input_image'
+inp_name = 'raw_img_bytes'
outp_name = 'preds'
-inp_dtype = 'FP32'
+inp_dtype = 'UINT8'
outp_dtype = np.int64
-inp_shape = [1, 3, 1024, 2048]
-outp_shape = [1024, 2048]
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]

+## input data and mean/std
+inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
+mean = np.array(mean, dtype=np.float32)[None, ...]
+std = np.array(std, dtype=np.float32)[None, ...]
+inputs = [service_pb2.ModelInferRequest().InferInputTensor() for _ in range(3)]
+inputs[0].name = inp_name
+inputs[0].datatype = inp_dtype
+inputs[0].shape.extend(inp_data.shape)
+inputs[1].name = 'channel_mean'
+inputs[1].datatype = 'FP32'
+inputs[1].shape.extend(mean.shape)
+inputs[2].name = 'channel_std'
+inputs[2].datatype = 'FP32'
+inputs[2].shape.extend(std.shape)
+inp_bytes = [inp_data.tobytes(), mean.tobytes(), std.tobytes()]

option = [
    ('grpc.max_receive_message_length', 1073741824),
    ('grpc.max_send_message_length', 1073741824),
@@ -52,37 +67,22 @@
request.model_name = model_name
request.model_version = model_version

-inp = service_pb2.ModelInferRequest().InferInputTensor()
-inp.name = inp_name
-inp.datatype = inp_dtype
-inp.shape.extend(inp_shape)
-
-mean = np.array(mean).reshape(1, 1, 3)
-std = np.array(std).reshape(1, 1, 3)
-im = cv2.imread(impth)[:, :, ::-1]
-im = cv2.resize(im, dsize=tuple(inp_shape[-1:-3:-1]))
-im = ((im / 255.) - mean) / std
-im = im[None, ...].transpose(0, 3, 1, 2)
-inp_bytes = im.astype(np.float32).tobytes()
-
request.ClearField("inputs")
request.ClearField("raw_input_contents")
-request.inputs.extend([inp,])
-request.raw_input_contents.extend([inp_bytes,])
+request.inputs.extend(inputs)
+request.raw_input_contents.extend(inp_bytes)

-outp = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
-outp.name = outp_name
-request.outputs.extend([outp,])

# sync
-# resp = grpc_stub.ModelInfer(request).raw_output_contents[0]
+# resp = grpc_stub.ModelInfer(request)
# async
resp = grpc_stub.ModelInfer.future(request)
-resp = resp.result().raw_output_contents[0]
+resp = resp.result()
+
+outp_bytes = resp.raw_output_contents[0]
+outp_shape = resp.outputs[0].shape

-out = np.frombuffer(resp, dtype=outp_dtype).reshape(*outp_shape)
+out = np.frombuffer(outp_bytes, dtype=outp_dtype).reshape(*outp_shape).squeeze()

out = palette[out]
cv2.imwrite('res.png', out)
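The imports sit above this hunk and are not shown in the diff; for orientation, the `service_pb2` objects and `grpc_stub` used here are presumably created along these lines (an editor's sketch based on the tritonclient package layout, not commit content):
```
import grpc
from tritonclient.grpc import service_pb2, service_pb2_grpc

channel = grpc.insecure_channel(url, options=option)
grpc_stub = service_pb2_grpc.GRPCInferenceServiceStub(channel)
request = service_pb2.ModelInferRequest()
```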

tis/client_http.py

+64
@@ -0,0 +1,64 @@
import argparse
import sys
import numpy as np
import cv2
import gevent.ssl

import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException


np.random.seed(123)
palette = np.random.randint(0, 256, (100, 3))


url = '10.128.61.8:8000'
# url = '127.0.0.1:8000'
model_name = 'bisenetv2'
model_version = '1'
inp_name = 'raw_img_bytes'
outp_name = 'preds'
inp_dtype = 'UINT8'
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]


## prepare image and mean/std
inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
mean = np.array(mean, dtype=np.float32)[None, ...]
std = np.array(std, dtype=np.float32)[None, ...]
inputs = []
inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype))
inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32'))
inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32'))
inputs[0].set_data_from_numpy(inp_data, binary_data=True)
inputs[1].set_data_from_numpy(mean, binary_data=True)
inputs[2].set_data_from_numpy(std, binary_data=True)


## client
triton_client = httpclient.InferenceServerClient(
        url=url, verbose=False, concurrency=32)

## infer
# sync
# results = triton_client.infer(model_name, inputs)

# async
results = triton_client.async_infer(
        model_name,
        inputs,
        outputs=None,
        query_params=None,
        headers=None,
        request_compression_algorithm=None,
        response_compression_algorithm=None)
results = results.get_result() # async infer only

# get output
outp = results.as_numpy(outp_name).squeeze()
out = palette[outp]
cv2.imwrite('res.png', out)

tis/cpp_client/CMakeLists.txt

+1-1
@@ -2,7 +2,7 @@ cmake_minimum_required (VERSION 3.18)

project(Samples)

-set(CMAKE_CXX_FLAGS "-std=c++14 -O1")
+set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
set(CMAKE_BUILD_TYPE Release)

set(CMAKE_PREFIX_PATH
