-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_tensorrt.py
More file actions
170 lines (134 loc) · 4.37 KB
/
export_tensorrt.py
File metadata and controls
170 lines (134 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Export ONNX Model to TensorRT Engine.
Optimizes model for NVIDIA GPU inference.
"""
import argparse
from pathlib import Path
try:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
from loguru import logger
def _add_dynamic_batch_profile(builder, network, config, max_batch_size):
    """Register an optimization profile for inputs with a dynamic batch dimension."""
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    dynamic_inputs = [t for t in inputs if any(d < 0 for d in t.shape)]
    if not dynamic_inputs:
        # Fully static network: no profile is needed.
        return

    if max_batch_size < 1:
        raise ValueError(
            f"max_batch_size must be >= 1 for dynamic-batch models, got {max_batch_size}"
        )

    profile = builder.create_optimization_profile()
    for tensor in dynamic_inputs:
        dims = list(tensor.shape)
        tail = tuple(dims[1:])
        if any(d < 0 for d in tail):
            # Only the batch size is known here; ranges for other dynamic
            # axes cannot be guessed, so leave the config untouched.
            logger.warning(
                f"Input '{tensor.name}' has dynamic non-batch dimensions {dims}; "
                "no optimization profile was added"
            )
            return
        min_shape = (1,) + tail
        max_shape = (max_batch_size,) + tail
        # Tune for the largest batch: argument order is (name, min, opt, max).
        profile.set_shape(tensor.name, min_shape, max_shape, max_shape)
        logger.info(
            f"Dynamic batch profile for '{tensor.name}': min={min_shape}, max={max_shape}"
        )
    config.add_optimization_profile(profile)


def build_tensorrt_engine(
    onnx_path: str,
    engine_path: str,
    fp16_mode: bool = True,
    int8_mode: bool = False,
    max_batch_size: int = 32,
    workspace_size: int = 1 << 30,  # 1GB
):
    """
    Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_path: Path to ONNX model
        engine_path: Output engine path
        fp16_mode: Enable FP16 precision. Only applied when the builder
            reports fast FP16 support; otherwise a warning is logged.
        int8_mode: Enable INT8 precision. Only applied when the builder
            reports fast INT8 support. No calibrator is configured here.
        max_batch_size: Maximum batch size. Used as the upper bound of the
            optimization profile for inputs whose leading dimension is
            dynamic; it has no effect on fully static networks.
        workspace_size: Workspace size in bytes

    Returns:
        The ``engine_path`` that was written.

    Raises:
        RuntimeError: If TensorRT could not be imported, the ONNX model
            fails to parse, or the engine build fails.
        ValueError: If the model has a dynamic batch dimension and
            ``max_batch_size`` is smaller than 1.
    """
    if not TENSORRT_AVAILABLE:
        raise RuntimeError("TensorRT not available")

    TRT_LOGGER = trt.Logger(trt.Logger.INFO)

    logger.info("Building TensorRT engine...")
    logger.info(f"ONNX model: {onnx_path}")
    logger.info(f"FP16 mode: {fp16_mode}")
    logger.info(f"INT8 mode: {int8_mode}")
    logger.info(f"Max batch size: {max_batch_size}")

    # Create builder
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # Parse ONNX
    logger.info("Parsing ONNX model...")
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            raise RuntimeError("Failed to parse ONNX model")

    # Configure builder
    config = builder.create_builder_config()
    # `max_workspace_size` was deprecated and later removed; prefer the
    # memory-pool API when this TensorRT version provides it.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)
    else:
        config.max_workspace_size = workspace_size

    if fp16_mode:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info("✓ FP16 mode enabled")
        else:
            logger.warning("FP16 requested but not supported on this platform")

    if int8_mode:
        if builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            logger.info("✓ INT8 mode enabled")
            # Note: INT8 requires calibration data
            logger.warning("INT8 enabled without a calibrator or calibration data")
        else:
            logger.warning("INT8 requested but not supported on this platform")

    # Without a profile, TensorRT cannot build networks with dynamic inputs.
    _add_dynamic_batch_profile(builder, network, config, max_batch_size)

    # Build engine
    logger.info("Building engine (this may take a while)...")
    # `build_engine` was deprecated and later removed; newer releases return
    # the serialized plan directly from `build_serialized_network`.
    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None

    if serialized is None:
        raise RuntimeError("Failed to build TensorRT engine")

    # Serialize and save
    logger.info(f"Saving engine to {engine_path}")
    with open(engine_path, "wb") as f:
        f.write(serialized)

    # Get engine size
    engine_size_mb = Path(engine_path).stat().st_size / (1024 * 1024)

    logger.info("✓ TensorRT engine built successfully!")
    logger.info(f"✓ Engine saved to: {engine_path}")
    logger.info(f"✓ Engine size: {engine_size_mb:.2f} MB")

    return engine_path
def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line is used, exactly as
            before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``onnx``, ``output``,
        ``fp16``, ``int8``, ``max_batch_size`` and ``workspace_size``.
        ``fp16`` and ``int8`` are opt-in flags and default to ``False``.
    """
    parser = argparse.ArgumentParser(description="Export ONNX to TensorRT")

    parser.add_argument(
        "--onnx",
        type=str,
        required=True,
        help="Path to ONNX model",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="exports/model.engine",
        help="Output TensorRT engine path",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Enable FP16 precision",
    )
    parser.add_argument(
        "--int8",
        action="store_true",
        help="Enable INT8 precision",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=32,
        help="Maximum batch size",
    )
    parser.add_argument(
        "--workspace-size",
        type=int,
        default=1073741824,  # 1GB
        help="Workspace size in bytes",
    )

    # Passing None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def main():
    """Command-line entry point: read the options, then build the engine."""
    cli = parse_args()

    # Make sure the destination directory exists before building.
    engine_file = Path(cli.output)
    engine_file.parent.mkdir(parents=True, exist_ok=True)

    # Collect the build options in one place and hand them over together.
    build_options = {
        "onnx_path": cli.onnx,
        "engine_path": str(engine_file),
        "fp16_mode": cli.fp16,
        "int8_mode": cli.int8,
        "max_batch_size": cli.max_batch_size,
        "workspace_size": cli.workspace_size,
    }
    build_tensorrt_engine(**build_options)

    logger.info("Export completed successfully!")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()