Skip to content

Commit 81abc62

Browse files
committed
add profile shape parsing
1 parent 8192a85 commit 81abc62

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

+47
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,53 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
11001100
profile_min_shapes = info.profile_min_shapes;
11011101
profile_max_shapes = info.profile_max_shapes;
11021102
profile_opt_shapes = info.profile_opt_shapes;
1103+
1104+
/*
1105+
* Parse explicit min/max/opt profile shapes from provider options.
1106+
*
1107+
* The format of min/max/opt profile shapes is defined as below:
1108+
* "input1:dim1xdim2...,input2:dim1xdim2...,...,input1:dim3xdim4...,input2:dim3xdim4...,..."
1109+
*
1110+
* (Note: if multiple shapes with same input name are specified, TRT EP will consider them as multiple profiles.
1111+
* Please refer to ParserProfileShapes() for more details)
1112+
*
1113+
*/
1114+
bool status = true;
1115+
if (status) {
1116+
status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_);
1117+
if (!status) {
1118+
profile_min_shapes_.clear();
1119+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'";
1120+
}
1121+
}
1122+
1123+
if (status) {
1124+
status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_);
1125+
if (!status) {
1126+
profile_max_shapes_.clear();
1127+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'";
1128+
}
1129+
}
1130+
1131+
if (status) {
1132+
status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_);
1133+
if (!status) {
1134+
profile_opt_shapes_.clear();
1135+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'";
1136+
}
1137+
}
1138+
1139+
if (status) {
1140+
status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_);
1141+
if (!status) {
1142+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile.";
1143+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you.";
1144+
profile_min_shapes_.clear();
1145+
profile_max_shapes_.clear();
1146+
profile_opt_shapes_.clear();
1147+
}
1148+
}
1149+
11031150
cuda_graph_enable_ = info.cuda_graph_enable;
11041151
op_types_to_exclude_ = info.op_types_to_exclude;
11051152

0 commit comments

Comments
 (0)