Skip to content

Commit ea69d95

Browse files
authored
Merge pull request #155 from Visual-Behavior/trt_export_from_onnx
Trt export from onnx
2 parents 567a376 + 27214bd commit ea69d95

File tree

2 files changed

+109
-13
lines changed

2 files changed

+109
-13
lines changed

alonet/torch2trt/base_exporter.py

+75-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import onnx
1313
import onnx_graphsurgeon as gs
1414
import tensorrt as trt
15-
15+
import pycuda.driver as cuda
1616
prod_package_error = None
1717
except Exception as prod_package_error:
1818
pass
@@ -21,6 +21,8 @@
2121
from contextlib import redirect_stdout, ExitStack
2222
from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
2323
from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
24+
from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_
25+
2426

2527

2628
class BaseTRTExporter:
@@ -51,6 +53,7 @@ def __init__(
5153
operator_export_type=None,
5254
dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
5355
opt_profiles: Dict[str, Tuple[List[int]]] = None,
56+
skip_adapt_graph=False,
5457
**kwargs,
5558
):
5659
"""
@@ -108,6 +111,7 @@ def __init__(
108111
self.custom_opset = None # to be redefine in child class if needed
109112
self.use_scope_names = use_scope_names
110113
self.operator_export_type = operator_export_type
114+
self.skip_adapt_graph = skip_adapt_graph
111115
if dynamic_axes is not None:
112116
assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
113117
assert isinstance(dynamic_axes, dict)
@@ -117,13 +121,19 @@ def __init__(
117121
onnx_dir = os.path.split(onnx_path)[0]
118122
onnx_file_name = os.path.split(onnx_path)[1]
119123
model_name = onnx_file_name.split(".")[0]
120-
self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
124+
125+
if not self.skip_adapt_graph:
126+
self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
127+
else:
128+
self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)
129+
121130
self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")
122131

123132
if self.verbose:
124133
trt_logger = trt.Logger(trt.Logger.VERBOSE)
125134
else:
126135
trt_logger = trt.Logger(trt.Logger.WARNING)
136+
127137
self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)
128138

129139
if precision.lower() == "fp32":
@@ -147,15 +157,59 @@ def build_torch_model(self):
147157
pass
148158
raise Exception("Child class should implement this method")
149159

160+
150161
def adapt_graph(self, graph):
    """Hook for child classes: apply model-specific ONNX graph fixes.

    The base implementation is a no-op that hands the graph back
    untouched; subclasses override it to ensure compatibility between
    their exported ONNX model and TensorRT.

    Returns
    -------
    graph: onnx_graphsurgeon.Graph
    """
    return graph
169+
170+
def _adapt_graph(self, graph, **kwargs):
    """Adapt an ONNX graph for TensorRT compatibility.

    Rewrites every ``Clip`` node so its min/max inputs become plain
    ``gs.Constant`` tensors (TensorRT cannot consume the producer
    subgraphs the torch exporter emits for them), then runs
    onnx-simplifier on the model loaded from ``self.onnx_path`` and
    finally delegates to the child-class hook :meth:`adapt_graph` for
    model-specific fixes.

    Parameters
    ----------
    graph : onnx_graphsurgeon.Graph
        Graph imported from the exported ONNX model.
    **kwargs
        Extra options forwarded by the caller and currently unused;
        accepted because ``_onnx2engine`` invokes this method with
        ``**kwargs``.

    Returns
    -------
    graph : onnx_graphsurgeon.Graph
        The adapted graph, ready to be re-exported for engine building.
    """
    clip_nodes = get_nodes_by_op("Clip", graph)

    def handle_op_clip(node: gs.Node):
        """Replace the min/max producers of a Clip node by gs.Constant inputs."""
        # Default max: largest float32, used when the node has no max producer.
        max_constant = np.array(np.finfo(np.float32).max, dtype=np.float32)
        min_producer = node.inputs[1].i().inputs[0]
        if "value" in min_producer.attrs:
            # Min comes from a Constant node; reuse its value.
            min_constant = min_producer.attrs["value"].values.astype(np.float32)
            if len(node.inputs[2].inputs) > 0:
                max_constant = node.inputs[2].i().inputs[0].attrs["value"].values.astype(np.float32)
        elif "to" in min_producer.attrs:
            # Min comes from a Cast node: fall back to the lowest float32.
            min_constant = np.array(np.finfo(np.float32).min, dtype=np.float32)
        else:
            raise RuntimeError(
                f"Unsupported Clip node {node.name!r}: cannot resolve its min/max producers"
            )
        node.inputs.pop(1)
        node.inputs.insert(1, gs.Constant(name=node.name + "_min", values=min_constant))
        node.inputs.pop(2)
        node.inputs.insert(2, gs.Constant(name=node.name + "_max", values=max_constant))

    for node in clip_nodes:
        handle_op_clip(node)

    # Simplify the whole model with onnx-simplifier; keep the current graph
    # untouched if the simplified model does not validate.
    from onnxsim import simplify

    model = onnx.load(self.onnx_path)
    model_simp, check = simplify(model)

    if check:
        print("\n[INFO] Simplified ONNX model validated. Graph optimized...")
        graph = gs.import_onnx(model_simp)
        graph.toposort()
        graph.cleanup()
    else:
        print("\n[INFO] ONNX model was not validated.")

    # Call the child class for specific graph adaptation
    graph = self.adapt_graph(graph)
    return graph
159213

160214
def prepare_sample_inputs(self) -> Tuple[Tuple[torch.Tensor], Dict[str, Union[torch.Tensor, None]]]:
161215
"""
@@ -247,6 +301,7 @@ def _torch2onnx(self):
247301
number2scope = get_scope_names(onnx_export_log, strict=False)
248302
graph = gs.import_onnx(onnx.load(self.onnx_path))
249303
graph = rename_tensors_(graph, number2scope, verbose=True)
304+
graph = rename_nodes_(graph, True)
250305
onnx.save(gs.export_onnx(graph), self.onnx_path)
251306

252307
print("Saved ONNX at:", self.onnx_path)
@@ -265,15 +320,15 @@ def _onnx2engine(self, **kwargs):
265320
if prod_package_error is not None:
266321
raise prod_package_error
267322

268-
graph = gs.import_onnx(onnx.load(self.onnx_path))
269-
graph.toposort()
270-
271-
# === Modify ONNX graph for TensorRT compability
272-
graph = self.adapt_graph(graph, **kwargs)
273-
utils.print_graph_io(graph)
323+
if not self.skip_adapt_graph:
324+
graph = gs.import_onnx(onnx.load(self.onnx_path))
325+
graph.toposort()
274326

275-
# === Export adapted onnx for TRT engine
276-
onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
327+
# === Modify ONNX graph for TensorRT compability
328+
graph = self._adapt_graph(graph, **kwargs)
329+
utils.print_graph_io(graph)
330+
# === Export adapted onnx for TRT engine
331+
onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
277332

278333
# === Build engine
279334
self.engine_builder.export_engine(self.engine_path)
@@ -286,7 +341,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
286341
threshold = 1e-1
287342
check = True
288343
# Get engine info
289-
model = TRTExecutor(engine)
344+
model = TRTExecutor(engine, stream=cuda.Stream())
290345
model.print_bindings_info()
291346
# Prepare engine inputs
292347
for i in range(len(sample_inputs)):
@@ -302,6 +357,7 @@ def sanity_check(self, engine, sample_inputs, sample_outputs):
302357
m_outputs = model.execute()
303358
print("==== Absolute / relavtive error:")
304359
for out in m_outputs:
360+
print('out', m_outputs[out])
305361
diff = m_outputs[out].astype(float) - sample_outputs[out].astype(float)
306362
abs_err = np.abs(diff)
307363
rel_err = np.abs(diff / (sample_outputs[out] + 1e-6)) # Avoid div by zero
@@ -332,7 +388,13 @@ def add_argparse_args(parent_parser):
332388
default=None,
333389
help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
334390
)
391+
parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the adapt graph")
335392
parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
336393
parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default FP32")
337394
parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")
395+
parser.add_argument(
396+
"--use_scope_names",
397+
action="store_true",
398+
help="Save scope names in onnx, to get profiles in inference by default %(default)s",
399+
)
338400
return parent_parser

alonet/torch2trt/utils.py

+34
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,37 @@ def execute_sync(context, bindings, inputs, outputs):
368368
for out in outputs:
369369
out.host = out.host.reshape(out.shape)
370370
return [out.host for out in outputs]
371+
372+
373+
374+
def rename_nodes_(graph, verbose=False):
    """Rename graph nodes in place so profiling output shows readable names.

    Each node whose name is not a graph input/output name takes the name of
    its first output tensor. When that name is purely numeric (i.e. the
    exporter left the node unnamed), the node is renamed after a named
    (non-numeric) input tensor instead, as ``<input_name>_<numeric_id>``;
    if several inputs are named, the last one wins.

    Parameters
    ----------
    graph : onnx_graphsurgeon.Graph
        Graph whose nodes are renamed in place.
    verbose : bool
        If True, print each numeric-name replacement.

    Returns
    -------
    graph : onnx_graphsurgeon.Graph
        The same graph object, with nodes renamed.
    """
    # Names of graph input/output tensors must be preserved.
    dont_rename = [v.name for v in graph.inputs + graph.outputs]

    for node in graph.nodes:
        if node.name in dont_rename:
            continue
        if not node.outputs:
            # No output tensor to borrow a name from; leave the node as is.
            continue

        # Replace name by output name to include in profiling.
        node.name = node.outputs[0].name

        # A purely numeric name means the exporter did not name this node.
        try:
            id_node = int(node.name)
        except (TypeError, ValueError):
            continue

        # Build a readable name from a named (non-numeric) input tensor.
        for inode in node.inputs:
            try:
                int(inode.name)  # numeric input names cannot help; skip them
                continue
            except (TypeError, ValueError):
                pass
            new_name = inode.name + "_" + str(id_node)
            if verbose:
                print(f" changed {node.name} to {new_name}")
            node.name = new_name

    return graph

0 commit comments

Comments (0)