Skip to content

Commit e64a490

Browse files
committed
Write unit test for Triton server
1 parent 0aca56a commit e64a490

File tree

2 files changed

+124
-124
lines changed
Lines changed: 119 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,121 @@
1-
import json
2-
import time
3-
4-
import pytest
5-
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
6-
from apache_beam.ml.inference.vertex_ai_inference import VertexAITritonModelHandler
1+
import unittest
2+
from unittest.mock import MagicMock, patch
73
from apache_beam.ml.inference.base import PredictionResult
8-
9-
# Define a fake endpoint class to simulate a live AI Platform endpoint.
10-
class FakeEndpoint:
11-
def __init__(self):
12-
self.resource_name = "projects/test/locations/us-central1/endpoints/12345"
13-
self.deployed_model_id = "deployed_model_1"
14-
15-
def list_models(self):
16-
return [{"dummy": "model"}]
17-
18-
19-
# Define a fake PredictionServiceClient to simulate the behavior of the AI Platform client.
20-
class FakePredictionServiceClient:
21-
def __init__(self, client_options=None):
22-
self.client_options = client_options
23-
24-
def raw_predict(self, request):
25-
# Simulate a response. The true handler expects the response data to be a JSON
26-
# string containing an "outputs" field.
27-
fake_resp_content = {
28-
"outputs": [{"data": [["result1"], ["result2"]]}]
4+
from apache_beam.ml.inference.vertex_ai_inference import VertexAITritonModelHandler
5+
import json
6+
from google.cloud import aiplatform
7+
from google.cloud.aiplatform.gapic import PredictionServiceClient
8+
9+
class TestVertexAITritonModelHandler(unittest.TestCase):
10+
def setUp(self):
11+
"""Initialize the handler with test parameters before each test."""
12+
self.handler = VertexAITritonModelHandler(
13+
project_id="test-project",
14+
region="us-central1",
15+
endpoint_name="test-endpoint",
16+
name="input__0",
17+
location="us-central1",
18+
datatype="FP32",
19+
private=False,
20+
)
21+
self.handler.throttler = MagicMock()
22+
self.handler.throttler.throttle_request = MagicMock(return_value=False)
23+
self.handler.throttled_secs = MagicMock()
24+
25+
def test_load_model_public_endpoint(self):
26+
"""Test loading a public endpoint and verifying deployed models."""
27+
mock_endpoint = MagicMock(spec=aiplatform.Endpoint)
28+
mock_endpoint.list_models.return_value = [MagicMock()]
29+
with patch('google.cloud.aiplatform.Endpoint', return_value=mock_endpoint):
30+
model = self.handler.load_model()
31+
self.assertEqual(model, mock_endpoint)
32+
mock_endpoint.list_models.assert_called_once()
33+
34+
def test_load_model_no_deployed_models(self):
35+
"""Test that an endpoint with no deployed models raises ValueError."""
36+
mock_endpoint = MagicMock(spec=aiplatform.Endpoint)
37+
mock_endpoint.list_models.return_value = []
38+
with patch('google.cloud.aiplatform.Endpoint', return_value=mock_endpoint):
39+
with self.assertRaises(ValueError) as cm:
40+
self.handler.load_model()
41+
self.assertIn("no models deployed", str(cm.exception))
42+
43+
def test_get_request_payload_scalar(self):
44+
"""Test payload construction for a batch of scalar inputs."""
45+
batch = [1.0, 2.0, 3.0]
46+
expected_payload = {
47+
"inputs": [
48+
{
49+
"name": "input__0",
50+
"shape": [3, 1],
51+
"datatype": "FP32",
52+
"data": [1.0, 2.0, 3.0],
53+
}
54+
]
2955
}
30-
class FakeResponse:
31-
pass
32-
fake_response = FakeResponse()
33-
fake_response.data = json.dumps(fake_resp_content).encode("utf-8")
34-
return fake_response
35-
36-
37-
@pytest.fixture(autouse=True)
38-
def patch_prediction_service_client(monkeypatch):
39-
# Replace the PredictionServiceClient with our fake implementation.
40-
monkeypatch.setattr(
41-
"apache_beam.ml.inference.vertex_ai_inference.aiplatform.gapic.PredictionServiceClient",
42-
lambda client_options=None: FakePredictionServiceClient(client_options)
43-
)
44-
45-
46-
@pytest.fixture(autouse=True)
47-
def patch_retrieve_endpoint(monkeypatch):
48-
# Patch the _retrieve_endpoint() to always return our fake endpoint.
49-
monkeypatch.setattr(
50-
VertexAITritonModelHandler,
51-
"_retrieve_endpoint",
52-
lambda self: FakeEndpoint()
53-
)
54-
55-
56-
@pytest.fixture(autouse=True)
57-
def patch_convert_to_result(monkeypatch):
58-
# Override the conversion utility to a simple function for testing.
59-
from apache_beam.ml.inference import utils
60-
def fake_convert_to_result(batch, predictions, deployed_model_id):
61-
# In our simple conversion, we return a tuple: (input, prediction, deployed_model_id)
62-
return [(inp, pred, deployed_model_id) for inp, pred in zip(batch, predictions)]
63-
monkeypatch.setattr(utils, "_convert_to_result", fake_convert_to_result)
64-
65-
66-
def test_get_request(monkeypatch):
67-
"""
68-
Test that the get_request method constructs the request properly,
69-
calls the fake PredictionServiceClient, and returns the expected response data.
70-
"""
71-
# Create a TritonModelHandler instance.
72-
handler = VertexAITritonModelHandler(
73-
project_id="test-project",
74-
region="us-central1",
75-
endpoint_name="12345",
76-
name="input_field",
77-
location="us-central1",
78-
datatype="BYTES"
79-
)
80-
81-
# Create a fake endpoint to pass into get_request.
82-
fake_endpoint = FakeEndpoint()
83-
batch = ["test_input"]
84-
85-
# Call get_request.
86-
response_data = handler.get_request(batch, fake_endpoint, throttle_delay_secs=1, inference_agrs={})
87-
88-
# Validate the response structure.
89-
assert "outputs" in response_data
90-
outputs = response_data["outputs"]
91-
assert isinstance(outputs, list)
92-
assert "data" in outputs[0]
93-
# The fake client returns a list with two prediction results.
94-
assert outputs[0]["data"] == [["result1"], ["result2"]]
95-
96-
97-
def test_run_inference(monkeypatch):
98-
"""
99-
Test that run_inference converts the raw response into PredictionResult objects.
100-
"""
101-
handler = VertexAITritonModelHandler(
102-
project_id="test-project",
103-
region="us-central1",
104-
endpoint_name="12345",
105-
name="input_field",
106-
location="us-central1",
107-
datatype="BYTES"
108-
)
109-
110-
# load_model() is patched via _retrieve_endpoint so it returns our FakeEndpoint.
111-
model = handler.load_model()
112-
batch = ["input1", "input2"]
113-
114-
# Call run_inference and collect the results.
115-
results = list(handler.run_inference(batch, model, inference_args={"payload_config": {}}))
116-
117-
# With our fake response (2 predictions) and a batch of 2 inputs,
118-
# the conversion utility should yield 2 PredictionResult tuples.
119-
# For our fake_convert_to_result, each result is a tuple: (input, prediction, deployed_model_id).
120-
assert len(results) == len(batch)
121-
for inp, pred, deployed_id in results:
122-
assert inp in batch
123-
assert deployed_id == model.deployed_model_id
56+
model = MagicMock(resource_name="test-resource")
57+
mock_client = MagicMock(spec=PredictionServiceClient)
58+
mock_response = MagicMock()
59+
mock_response.data.decode.return_value = json.dumps({"outputs": [{"data": [0.5, 0.6, 0.7]}]})
60+
mock_client.raw_predict.return_value = mock_response
61+
62+
with patch('google.cloud.aiplatform.gapic.PredictionServiceClient', return_value=mock_client):
63+
self.handler.get_request(batch, model, throttle_delay_secs=5, inference_args=None)
64+
request = mock_client.raw_predict.call_args[1]["request"]
65+
self.assertEqual(json.loads(request.http_body.data.decode("utf-8")), expected_payload)
66+
67+
def test_run_inference_parse_response(self):
68+
"""Test parsing of a Triton response into PredictionResult objects."""
69+
batch = [1.0, 2.0]
70+
mock_response = {
71+
"outputs": [
72+
{
73+
"name": "output__0",
74+
"shape": [2, 1],
75+
"datatype": "FP32",
76+
"data": [0.5, 0.6],
77+
}
78+
]
79+
}
80+
model = MagicMock(resource_name="test-resource", deployed_model_id="model-123")
81+
with patch.object(self.handler, 'get_request', return_value=mock_response):
82+
results = self.handler.run_inference(batch, model)
83+
expected_results = [
84+
PredictionResult(example=1.0, inference=[0.5], model_id="model-123"),
85+
PredictionResult(example=2.0, inference=[0.6], model_id="model-123"),
86+
]
87+
self.assertEqual(list(results), expected_results)
88+
89+
def test_run_inference_empty_batch(self):
90+
"""Test that an empty batch returns an empty list."""
91+
batch = []
92+
model = MagicMock()
93+
results = self.handler.run_inference(batch, model)
94+
self.assertEqual(list(results), [])
95+
96+
def test_run_inference_malformed_response(self):
97+
"""Test that a malformed response raises an error."""
98+
batch = [1.0]
99+
mock_response = {"unexpected": "data"}
100+
model = MagicMock()
101+
with patch.object(self.handler, 'get_request', return_value=mock_response):
102+
with self.assertRaises(ValueError) as cm:
103+
list(self.handler.run_inference(batch, model))
104+
self.assertIn("no outputs found", str(cm.exception))
105+
106+
def test_throttling_delays_request(self):
107+
"""Test that the handler delays requests when throttled."""
108+
batch = [1.0]
109+
model = MagicMock(resource_name="test-resource")
110+
self.handler.throttler.throttle_request = MagicMock(side_effect=[True, False])
111+
mock_response = {"outputs": [{"data": [0.5]}]}
112+
113+
with patch('time.sleep') as mock_sleep:
114+
with patch('google.cloud.aiplatform.gapic.PredictionServiceClient') as mock_client:
115+
mock_client.return_value.raw_predict.return_value.data.decode.return_value = json.dumps(mock_response)
116+
self.handler.run_inference(batch, model)
117+
mock_sleep.assert_called_with(5)
118+
self.handler.throttled_secs.inc.assert_called_with(5)
119+
120+
if __name__ == "__main__":
121+
unittest.main()

sdks/python/apache_beam/ml/inference/vertex_ai_inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from typing import Mapping
2525
from typing import Optional
2626
from typing import Sequence
27-
27+
from google.api import httpbody_pb2
2828
from google.api_core.exceptions import ServerError
2929
from google.api_core.exceptions import TooManyRequests
3030
from google.cloud import aiplatform
@@ -346,7 +346,7 @@ def get_request(
346346
batch: Sequence[Any],
347347
model: aiplatform.Endpoint,
348348
throttle_delay_secs: int,
349-
inference_agrs: Optional[Dict[str,Any]]):
349+
inference_args: Optional[Dict[str,Any]]):
350350
while self.throttler.throttle_request(time.time() * MSEC_TO_SEC):
351351
LOGGER.info(
352352
"Delaying request for %d seconds due to previous failures",
@@ -369,7 +369,7 @@ def get_request(
369369
client_options = {"api_endpoint": api_endpoint}
370370
pred_client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
371371
request = aiplatform.gapic.RawPredictRequest(endpoint=model.resource_name,
372-
http_body=aiplatform.gapic.HttpBody(data=body, content_type="application/json"),)
372+
http_body=httpbody_pb2.HttpBody(data=body, content_type="application/json"),)
373373
response = pred_client.raw_predict(request = request)
374374
response_data = json.loads(response.data.decode('utf-8'))
375375
return response_data
@@ -394,6 +394,8 @@ def run_inference(
394394
Returns:
395395
An iterable of Predictions.
396396
"""
397+
if not batch:
398+
return []
397399
prediction = self.get_request(
398400
batch,model,throttle_delay_secs=5,inference_args=inference_args
399401
)

0 commit comments

Comments (0)