1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
2
import argparse
3
+ import json
4
+
3
5
import cv2
4
- from tritonclient .grpc import InferenceServerClient , InferInput , InferRequestedOutput
5
6
import numpy as np
6
- import json
7
+ from tritonclient .grpc import (InferenceServerClient , InferInput ,
8
+ InferRequestedOutput )
7
9
8
10
9
11
def parse_args ():
10
12
parser = argparse .ArgumentParser ()
11
- parser .add_argument ('model_name' , type = str ,
12
- help = 'model name' )
13
- parser .add_argument ('image' , type = str ,
14
- help = 'image path' )
13
+ parser .add_argument ('model_name' , type = str , help = 'model name' )
14
+ parser .add_argument ('image' , type = str , help = 'image path' )
15
15
return parser .parse_args ()
16
16
17
17
@@ -24,14 +24,16 @@ def __init__(self, url, model_name, model_version):
24
24
self ._client = InferenceServerClient (self ._url )
25
25
model_config = self ._client .get_model_config (self ._model_name ,
26
26
self ._model_version )
27
- model_metadata = self ._client .get_model_metadata (self . _model_name ,
28
- self ._model_version )
27
+ model_metadata = self ._client .get_model_metadata (
28
+ self . _model_name , self ._model_version )
29
29
print (f'[model config]:\n { model_config } ' )
30
30
print (f'[model metadata]:\n { model_metadata } ' )
31
31
self ._inputs = {input .name : input for input in model_metadata .inputs }
32
32
self ._input_names = list (self ._inputs )
33
33
self ._outputs = {
34
- output .name : output for output in model_metadata .outputs }
34
+ output .name : output
35
+ for output in model_metadata .outputs
36
+ }
35
37
self ._output_names = list (self ._outputs )
36
38
self ._outputs_req = [
37
39
InferRequestedOutput (name ) for name in self ._outputs
@@ -46,10 +48,10 @@ def infer(self, image, box):
46
48
results: dict, {name : numpy.array}
47
49
"""
48
50
49
- inputs = [InferInput ( self . _input_names [ 0 ], image . shape ,
50
- " UINT8" ),
51
- InferInput (self ._input_names [1 ], box .shape ,
52
- "BYTES" ) ]
51
+ inputs = [
52
+ InferInput ( self . _input_names [ 0 ], image . shape , ' UINT8' ),
53
+ InferInput (self ._input_names [1 ], box .shape , 'BYTES' )
54
+ ]
53
55
inputs [0 ].set_data_from_numpy (image )
54
56
inputs [1 ].set_data_from_numpy (box )
55
57
results = self ._client .infer (
@@ -72,20 +74,18 @@ def visualize(img, results):
72
74
cv2 .imwrite ('keypoint-detection.jpg' , img )
73
75
74
76
75
- if __name__ == " __main__" :
77
+ if __name__ == ' __main__' :
76
78
args = parse_args ()
77
79
model_name = args .model_name
78
- model_version = "1"
79
- url = " localhost:8001"
80
+ model_version = '1'
81
+ url = ' localhost:8001'
80
82
client = GRPCTritonClient (url , model_name , model_version )
81
83
img = cv2 .imread (args .image )
82
84
bbox = {
83
85
'type' : 'PoseBbox' ,
84
- 'value' : [
85
- {
86
- 'bbox' : [0.0 , 0.0 , img .shape [1 ], img .shape [0 ]]
87
- }
88
- ]
86
+ 'value' : [{
87
+ 'bbox' : [0.0 , 0.0 , img .shape [1 ], img .shape [0 ]]
88
+ }]
89
89
}
90
90
bbox = np .array ([json .dumps (bbox ).encode ('utf-8' )])
91
91
results = client .infer (img , bbox )
0 commit comments