@@ -12,7 +12,7 @@
     import onnx
     import onnx_graphsurgeon as gs
     import tensorrt as trt
-
+    import pycuda.driver as cuda
     prod_package_error = None
 except Exception as prod_package_error:
     pass
@@ -21,6 +21,8 @@
 from contextlib import redirect_stdout, ExitStack
 from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
 from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
+from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_
+


 class BaseTRTExporter:
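
Note on the optional-import guard above: in Python 3 the name bound by `except ... as` is deleted when the except block exits, so after a failed import `prod_package_error` ends up unbound rather than holding the exception. A minimal sketch of a more defensive variant (illustrative, not part of this commit):

    try:
        import tensorrt as trt
        import pycuda.driver as cuda

        prod_package_error = None
    except Exception as e:
        prod_package_error = e  # keep the exception; re-raise it only when TRT features are used
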
@@ -51,6 +53,7 @@ def __init__(
         operator_export_type=None,
         dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
         opt_profiles: Dict[str, Tuple[List[int]]] = None,
+        skip_adapt_graph=False,
         **kwargs,
     ):
         """
@@ -108,6 +111,7 @@ def __init__(
         self.custom_opset = None  # to be redefined in child class if needed
         self.use_scope_names = use_scope_names
         self.operator_export_type = operator_export_type
+        self.skip_adapt_graph = skip_adapt_graph
         if dynamic_axes is not None:
             assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
             assert isinstance(dynamic_axes, dict)
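
For context, `dynamic_axes` must come with `opt_profiles`, which maps each input name to its (min, opt, max) shapes. A usage sketch with the new flag (exporter subclass, paths, and shapes are illustrative):

    # Hypothetical subclass of BaseTRTExporter; values are examples only.
    exporter = MyModelTRTExporter(
        onnx_path="~/.aloception/weights/model/model.onnx",
        precision="fp16",
        dynamic_axes={"input": {0: "batch"}},
        opt_profiles={"input": ([1, 3, 480, 640], [4, 3, 480, 640], [8, 3, 480, 640])},
        skip_adapt_graph=False,
    )
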
@@ -117,13 +121,19 @@ def __init__(
         onnx_dir = os.path.split(onnx_path)[0]
         onnx_file_name = os.path.split(onnx_path)[1]
         model_name = onnx_file_name.split(".")[0]
-        self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
+
+        if not self.skip_adapt_graph:
+            self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
+        else:
+            self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)
+
         self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")

         if self.verbose:
             trt_logger = trt.Logger(trt.Logger.VERBOSE)
         else:
             trt_logger = trt.Logger(trt.Logger.WARNING)
+
         self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)

         if precision.lower() == "fp32":
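
To illustrate the path logic above, assuming onnx_path="/weights/detr/detr.onnx" and precision="fp16":

    # skip_adapt_graph=False -> adapted_onnx_path = "/weights/detr/trt_detr.onnx"
    # skip_adapt_graph=True  -> adapted_onnx_path = "/weights/detr/detr.onnx" (exported ONNX used as-is)
    # engine_path = "/weights/detr/detr_fp16.engine" in both cases
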
@@ -147,15 +157,59 @@ def build_torch_model(self):
         pass
         raise Exception("Child class should implement this method")

+
     def adapt_graph(self, graph):
         """Modify ONNX graph to ensure compatibility between ONNX and TensorRT

         Returns
         -------
         graph: onnx_graphsurgeon.Graph
         """
-        pass
-        raise Exception("Child class should implement this method")
+        return graph
+
+    def _adapt_graph(self, graph):
+        """Modify ONNX graph to ensure compatibility between ONNX and TensorRT
+
+        Returns
+        -------
+        graph: onnx_graphsurgeon.Graph
+        """
+        clip_nodes = get_nodes_by_op("Clip", graph)
+        def handle_op_Clip(node: gs.Node):
+            max_constant = np.array(np.finfo(np.float32).max, dtype=np.float32)
+            if "value" in node.inputs[1].i().inputs[0].attrs:
+                min_constant = node.inputs[1].i().inputs[0].attrs["value"].values.astype(np.float32)
+                if len(node.inputs[2].inputs) > 0:
+                    max_constant = node.inputs[2].i().inputs[0].attrs["value"].values.astype(np.float32)
+            elif "to" in node.inputs[1].i().inputs[0].attrs:
+                min_constant = np.array(np.finfo(np.float32).min, dtype=np.float32)
+            else:
+                raise Exception("Unsupported Clip node: cannot resolve min/max constants")
+            node.inputs.pop(1)
+            node.inputs.insert(1, gs.Constant(name=node.name + "_min", values=min_constant))
+            node.inputs.pop(2)
+            node.inputs.insert(2, gs.Constant(name=node.name + "_max", values=max_constant))
+
+        for n in clip_nodes:
+            handle_op_Clip(n)
+
+        from onnxsim import simplify
+        model = onnx.load(self.onnx_path)
+        check = False
+        model_simp, check = simplify(model)
+
+        if check:
+            print("\n[INFO] Simplified ONNX model validated. Graph optimized...")
+            graph = gs.import_onnx(model_simp)
+            graph.toposort()
+            graph.cleanup()
+        else:
+            print("\n[INFO] ONNX model was not validated.")
+
+
+        # Call the child class for model-specific graph adaptation
+        graph = self.adapt_graph(graph)
+        return graph

     def prepare_sample_inputs(self) -> Tuple[Tuple[torch.Tensor], Dict[str, Union[torch.Tensor, None]]]:
         """
@@ -247,6 +301,7 @@ def _torch2onnx(self):
             number2scope = get_scope_names(onnx_export_log, strict=False)
             graph = gs.import_onnx(onnx.load(self.onnx_path))
             graph = rename_tensors_(graph, number2scope, verbose=True)
+            graph = rename_nodes_(graph, True)
             onnx.save(gs.export_onnx(graph), self.onnx_path)

         print("Saved ONNX at:", self.onnx_path)
@@ -265,15 +320,15 @@ def _onnx2engine(self, **kwargs):
         if prod_package_error is not None:
             raise prod_package_error

-        graph = gs.import_onnx(onnx.load(self.onnx_path))
-        graph.toposort()
-
-        # === Modify ONNX graph for TensorRT compability
-        graph = self.adapt_graph(graph, **kwargs)
-        utils.print_graph_io(graph)
+        if not self.skip_adapt_graph:
+            graph = gs.import_onnx(onnx.load(self.onnx_path))
+            graph.toposort()

-        # === Export adapted onnx for TRT engine
-        onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
+            # === Modify ONNX graph for TensorRT compatibility
+            graph = self._adapt_graph(graph, **kwargs)
+            utils.print_graph_io(graph)
+            # === Export adapted onnx for TRT engine
+            onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)

         # === Build engine
         self.engine_builder.export_engine(self.engine_path)
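
The resulting control flow, as a sketch:

    # skip_adapt_graph=False: onnx_path -> _adapt_graph (Clip fix, onnx-simplifier,
    #                         then the child's adapt_graph) -> trt_<name>.onnx -> engine
    # skip_adapt_graph=True:  adapted_onnx_path already points at onnx_path -> engine

Note that `_adapt_graph` as defined above only accepts `graph`, so the `**kwargs` forwarded here would raise a TypeError if non-empty.
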
@@ -286,7 +341,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
         threshold = 1e-1
         check = True
         # Get engine info
-        model = TRTExecutor(engine)
+        model = TRTExecutor(engine, stream=cuda.Stream())
         model.print_bindings_info()
         # Prepare engine inputs
         for i in range(len(sample_inputs)):
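
`cuda.Stream()` requires an active CUDA context at call time. A minimal sketch of the prerequisite (assuming `TRTExecutor` does not create a context itself):

    import pycuda.autoinit  # noqa: F401 -- creates and activates a default CUDA context on import
    import pycuda.driver as cuda

    stream = cuda.Stream()  # the stream handed to TRTExecutor for execution
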
@@ -302,6 +357,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
         m_outputs = model.execute()
         print("==== Absolute / relative error:")
         for out in m_outputs:
+            print("out", m_outputs[out])
             diff = m_outputs[out].astype(float) - sample_outputs[out].astype(float)
             abs_err = np.abs(diff)
             rel_err = np.abs(diff / (sample_outputs[out] + 1e-6))  # Avoid division by zero
@@ -332,7 +388,13 @@ def add_argparse_args(parent_parser):
             default=None,
             help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
         )
+        parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the graph adaptation step")
         parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
         parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default FP32")
         parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")
+        parser.add_argument(
+            "--use_scope_names",
+            action="store_true",
+            help="Save scope names in ONNX, to get profiles at inference, by default %(default)s",
+        )
         return parent_parser
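
A sketch of how the new flags surface through the parser (wiring is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser = BaseTRTExporter.add_argparse_args(parser)
    args = parser.parse_args(["--precision", "fp16", "--skip_adapt_graph", "--use_scope_names"])
    # args.skip_adapt_graph -> True, args.use_scope_names -> True
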