
Commit 43e0fbb

refactor: Use safer backend APIs (#304)
1 parent f1fec69 commit 43e0fbb

File tree: 1 file changed (+14, -5 lines)

src/onnxruntime.cc

Lines changed: 14 additions & 5 deletions
@@ -2238,12 +2238,16 @@ ModelInstanceState::SetInputTensors(
           TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
       const int64_t* input_shape;
       uint32_t input_dims_count;
+      int64_t element_cnt = 0;
       RESPOND_AND_SET_NULL_IF_ERROR(
           &((*responses)[idx]), TRITONBACKEND_InputProperties(
                                     input, nullptr, nullptr, &input_shape,
                                     &input_dims_count, nullptr, nullptr));
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          &((*responses)[idx]),
+          GetElementCount(input_shape, input_dims_count, &element_cnt));
 
-      batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
+      batchn_shape[0] += element_cnt;
     }
   }
   // The shape for the entire input batch, [total_batch_size, ...]

@@ -2402,8 +2406,10 @@ ModelInstanceState::SetStringInputTensor(
       expected_byte_sizes.push_back(0);
       expected_element_cnts.push_back(0);
     } else {
-      expected_element_cnts.push_back(
-          GetElementCount(input_shape, input_dims_count));
+      int64_t element_cnt = 0;
+      RETURN_IF_ERROR(
+          GetElementCount(input_shape, input_dims_count, &element_cnt));
+      expected_element_cnts.push_back(element_cnt);
       expected_byte_sizes.push_back(input_byte_size);
     }
 

@@ -2573,8 +2579,9 @@ ModelInstanceState::ReadOutputTensor(
   ONNXTensorElementDataType type;
   RETURN_IF_ORT_ERROR(ort_api->GetTensorElementType(type_and_shape, &type));
   if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
-    const size_t element_count = GetElementCount(batchn_shape);
+    int64_t element_count = 0;
     size_t total_length = 0;
+    RETURN_IF_ERROR(GetElementCount(batchn_shape, &element_count));
     RETURN_IF_ORT_ERROR(
         ort_api->GetStringTensorDataLength(output_tensor, &total_length));
 

@@ -2776,7 +2783,9 @@ ModelInstanceState::SetStringBuffer(
     (*batchn_shape)[0] = shape[0];
   }
 
-  const size_t expected_element_cnt = GetElementCount(*batchn_shape);
+  int64_t expected_element_cnt = 0;
+  RESPOND_AND_SET_NULL_IF_ERROR(
+      &response, GetElementCount(*batchn_shape, &expected_element_cnt));
 
   // If 'request' requested this output then copy it from
   // 'content'. If it did not request this output then just skip it
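
For context, the "safer" API referenced in the commit title is an error-returning overload of GetElementCount that reports failure through Triton's TRITONSERVER_Error* convention, so each call site can propagate problems with RETURN_IF_ERROR or RESPOND_AND_SET_NULL_IF_ERROR instead of consuming a possibly invalid count. The sketch below is a minimal illustration of that calling convention only; it assumes the "safer" behavior means rejecting negative dimensions and guarding against int64_t overflow. The real helper ships with Triton's backend-common utilities and may differ.

// Illustrative sketch of an error-returning element-count helper.
// Not the actual implementation from Triton's backend-common library.
#include <cstdint>
#include <limits>

#include "triton/core/tritonserver.h"

TRITONSERVER_Error*
GetElementCount(const int64_t* shape, uint32_t dims_count, int64_t* count)
{
  int64_t cnt = 1;
  for (uint32_t d = 0; d < dims_count; ++d) {
    const int64_t dim = shape[d];
    if (dim < 0) {
      // Dynamic (-1) or otherwise invalid dimension: report it rather than
      // producing a nonsensical count.
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_INVALID_ARG,
          "shape contains a negative dimension");
    }
    if ((dim != 0) && (cnt > std::numeric_limits<int64_t>::max() / dim)) {
      // Multiplying by this dimension would overflow int64_t.
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_INVALID_ARG,
          "element count exceeds the range of int64_t");
    }
    cnt *= dim;
  }
  *count = cnt;
  return nullptr;  // success
}

With this shape, each call site in the diff wraps the helper in the macro that matches its error path: RETURN_IF_ERROR where the enclosing function itself returns a TRITONSERVER_Error*, and RESPOND_AND_SET_NULL_IF_ERROR where the error should be attached to the in-flight response instead.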

0 commit comments