Skip to content

Commit c8e9d32

Browse files
No public description
PiperOrigin-RevId: 758408976
1 parent 511bb82 commit c8e9d32

File tree

2 files changed

+180
-0
lines changed

2 files changed

+180
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Prediction from the Triton server."""
16+
17+
from typing import Any
18+
import cv2
19+
import numpy as np
20+
from tritonclient import grpc as triton_grpc
21+
22+
23+
_OUTPUT_KEYS = (
24+
'detection_classes',
25+
'detection_masks',
26+
'detection_boxes',
27+
'image_info',
28+
'num_detections',
29+
'detection_scores',
30+
)
31+
_OUTPUTS = tuple(triton_grpc.InferRequestedOutput(key) for key in _OUTPUT_KEYS)
32+
33+
34+
def prepare_image(
35+
path: str, height: int, width: int
36+
) -> tuple[triton_grpc.InferInput, np.ndarray, np.ndarray]:
37+
"""Prepares an image and converts it to an input for a Triton model server.
38+
39+
Args:
40+
path: The file path to the image that needs to be processed.
41+
height: The height of the image to be resized.
42+
width: The width of the image to be resized.
43+
44+
Returns:
45+
A tuple with the triton InferInput and both the original and resized
46+
image.
47+
"""
48+
image_bgr = cv2.imread(path)
49+
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
50+
image_resized = cv2.resize(
51+
image, (width, height), interpolation=cv2.INTER_AREA
52+
)
53+
expanded_image = np.expand_dims(image_resized, axis=0)
54+
inputs = triton_grpc.InferInput(
55+
'inputs', expanded_image.shape, datatype='UINT8'
56+
)
57+
inputs.set_data_from_numpy(expanded_image)
58+
return inputs, image, image_resized
59+
60+
61+
def infer(
62+
model_name: str, inputs: triton_grpc.InferInput
63+
) -> dict[str, Any]:
64+
"""Wraps inference and converts the result to a dictionary of output keys.
65+
66+
Args:
67+
model_name: Model name in Triton Server.
68+
inputs: The input data for inference.
69+
70+
Returns:
71+
A dictionary of output keys and their corresponding values from
72+
InferResult.
73+
"""
74+
result = triton_grpc.InferenceServerClient(url='localhost:8001').infer(
75+
model_name=model_name, inputs=[inputs], outputs=_OUTPUTS
76+
)
77+
if result:
78+
return {key: result.as_numpy(key) for key in _OUTPUT_KEYS}
79+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest import mock
17+
import cv2
18+
import numpy as np
19+
from tritonclient import grpc as triton_grpc
20+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_inference
21+
22+
# Create a small 1x4 BGR (open cv default) test image
23+
BGR_TEST_IMAGE = np.zeros((1, 4, 3), dtype=np.uint8)
24+
BGR_TEST_IMAGE[0, 0] = [0, 0, 255] # Red in BGR
25+
BGR_TEST_IMAGE[0, 1] = [0, 255, 0]
26+
BGR_TEST_IMAGE[0, 2] = [255, 0, 0] # Blue in BGR
27+
BGR_TEST_IMAGE[0, 3] = [0, 255, 255]
28+
29+
30+
class TestTritonPrediction(unittest.TestCase):
31+
32+
@mock.patch.object(cv2, 'imread')
33+
def test_input_conversion_to_rgb(self, mock_imread):
34+
mock_imread.return_value = BGR_TEST_IMAGE
35+
36+
_, test_image, _ = (
37+
triton_server_inference.prepare_image('/path/test_img.jpg', 5, 5)
38+
)
39+
40+
# Check that a single BRG pixel is converted to RGB
41+
self.assertEqual(test_image[0, 0].tolist(), [255, 0, 0])
42+
43+
@mock.patch.object(cv2, 'imread')
44+
def test_input_image_resized(self, mock_imread):
45+
mock_imread.return_value = BGR_TEST_IMAGE
46+
47+
_, _, test_image_resized = (
48+
triton_server_inference.prepare_image('/path/test_img.jpg', 5, 5)
49+
)
50+
51+
self.assertEqual(test_image_resized.shape, (5, 5, 3))
52+
53+
@mock.patch.object(cv2, 'imread')
54+
def test_batch_dimension_prepended_to_triton_input(self, mock_imread):
55+
mock_imread.return_value = BGR_TEST_IMAGE
56+
57+
test_triton_input, _, _ = (
58+
triton_server_inference.prepare_image('/path/test_img.jpg', 5, 5)
59+
)
60+
61+
self.assertEqual(test_triton_input.shape(), [1, 5, 5, 3])
62+
63+
@mock.patch.object(cv2, 'imread')
64+
def test_image_converted_to_infer_input(self, mock_imread):
65+
mock_imread.return_value = BGR_TEST_IMAGE
66+
67+
test_triton_input, _, _ = (
68+
triton_server_inference.prepare_image('/path/test_img.jpg', 5, 5)
69+
)
70+
71+
self.assertIsInstance(test_triton_input, triton_grpc.InferInput)
72+
73+
@mock.patch.object(triton_grpc.InferInput, 'set_data_from_numpy')
74+
@mock.patch.object(cv2, 'imread')
75+
def test_infer_input_set(self, mock_imread, mock_set_data_from_numpy):
76+
mock_imread.return_value = BGR_TEST_IMAGE
77+
78+
triton_server_inference.prepare_image('/path/test_img.jpg', 5, 5)
79+
80+
# Check that the set_data_from_numpy method is called once. Triton
81+
# InferInput data is a black-box, so we just check that it was set.
82+
mock_set_data_from_numpy.assert_called_once()
83+
84+
@mock.patch.object(triton_grpc.InferenceServerClient, 'infer')
85+
def test_inference_output_converted_to_dict(self, mock_query_model):
86+
test_output_data = np.array([[1, 0]])
87+
mock_infer_result = mock.create_autospec(
88+
triton_grpc.InferResult, instance=True
89+
)
90+
mock_infer_result.as_numpy = lambda key: test_output_data
91+
mock_query_model.return_value = mock_infer_result
92+
93+
result = triton_server_inference.infer('test_model', mock.MagicMock())
94+
95+
for key in triton_server_inference._OUTPUT_KEYS:
96+
self.assertIn(key, result)
97+
self.assertIsInstance(result[key], np.ndarray)
98+
99+
100+
if __name__ == '__main__':
101+
unittest.main()

0 commit comments

Comments
 (0)