import json
import time
import unittest
from unittest.mock import MagicMock, patch

import pytest
from google.cloud import aiplatform
from google.cloud.aiplatform.gapic import PredictionServiceClient

from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from apache_beam.ml.inference.vertex_ai_inference import VertexAITritonModelHandler
# Define a fake endpoint class to simulate a live AI Platform endpoint.
class FakeEndpoint:
    """Stand-in for an ``aiplatform.Endpoint`` used by the pytest-style tests."""

    def __init__(self):
        # Fully-qualified resource name in the format the real client returns.
        self.resource_name = "projects/test/locations/us-central1/endpoints/12345"
        self.deployed_model_id = "deployed_model_1"

    def list_models(self):
        # A non-empty list signals that at least one model is deployed.
        return [{"dummy": "model"}]
18-
19- # Define a fake PredictionServiceClient to simulate the behavior of the AI Platform client.
20- class FakePredictionServiceClient :
21- def __init__ (self , client_options = None ):
22- self .client_options = client_options
23-
24- def raw_predict (self , request ):
25- # Simulate a response. The true handler expects the response data to be a JSON
26- # string containing an "outputs" field.
27- fake_resp_content = {
28- "outputs" : [{"data" : [["result1" ], ["result2" ]]}]
4+ from apache_beam .ml .inference .vertex_ai_inference import VertexAITritonModelHandler
5+ import json
6+ from google .cloud import aiplatform
7+ from google .cloud .aiplatform .gapic import PredictionServiceClient
8+
class TestVertexAITritonModelHandler(unittest.TestCase):
    def setUp(self):
        """Initialize the handler with test parameters before each test."""
        self.handler = VertexAITritonModelHandler(
            project_id="test-project",
            region="us-central1",
            endpoint_name="test-endpoint",
            name="input__0",
            location="us-central1",
            datatype="FP32",
            private=False,
        )
        # Neutralize throttling so tests run without artificial delays.
        self.handler.throttler = MagicMock()
        self.handler.throttler.throttle_request = MagicMock(return_value=False)
        self.handler.throttled_secs = MagicMock()
24+
25+ def test_load_model_public_endpoint (self ):
26+ """Test loading a public endpoint and verifying deployed models."""
27+ mock_endpoint = MagicMock (spec = aiplatform .Endpoint )
28+ mock_endpoint .list_models .return_value = [MagicMock ()]
29+ with patch ('google.cloud.aiplatform.Endpoint' , return_value = mock_endpoint ):
30+ model = self .handler .load_model ()
31+ self .assertEqual (model , mock_endpoint )
32+ mock_endpoint .list_models .assert_called_once ()
33+
34+ def test_load_model_no_deployed_models (self ):
35+ """Test that an endpoint with no deployed models raises ValueError."""
36+ mock_endpoint = MagicMock (spec = aiplatform .Endpoint )
37+ mock_endpoint .list_models .return_value = []
38+ with patch ('google.cloud.aiplatform.Endpoint' , return_value = mock_endpoint ):
39+ with self .assertRaises (ValueError ) as cm :
40+ self .handler .load_model ()
41+ self .assertIn ("no models deployed" , str (cm .exception ))
42+
43+ def test_get_request_payload_scalar (self ):
44+ """Test payload construction for a batch of scalar inputs."""
45+ batch = [1.0 , 2.0 , 3.0 ]
46+ expected_payload = {
47+ "inputs" : [
48+ {
49+ "name" : "input__0" ,
50+ "shape" : [3 , 1 ],
51+ "datatype" : "FP32" ,
52+ "data" : [1.0 , 2.0 , 3.0 ],
53+ }
54+ ]
2955 }
@pytest.fixture(autouse=True)
def patch_prediction_service_client(monkeypatch):
    # Replace the PredictionServiceClient with our fake implementation so no
    # network call is ever attempted.
    monkeypatch.setattr(
        "apache_beam.ml.inference.vertex_ai_inference.aiplatform.gapic.PredictionServiceClient",
        lambda client_options=None: FakePredictionServiceClient(client_options),
    )
44-
45-
@pytest.fixture(autouse=True)
def patch_retrieve_endpoint(monkeypatch):
    # Patch _retrieve_endpoint() to always return our fake endpoint.
    monkeypatch.setattr(
        VertexAITritonModelHandler,
        "_retrieve_endpoint",
        lambda self: FakeEndpoint(),
    )
54-
55-
@pytest.fixture(autouse=True)
def patch_convert_to_result(monkeypatch):
    # Override the conversion utility with a simple function for testing.
    from apache_beam.ml.inference import utils

    def fake_convert_to_result(batch, predictions, deployed_model_id):
        # Return one (input, prediction, deployed_model_id) tuple per input.
        return [(inp, pred, deployed_model_id) for inp, pred in zip(batch, predictions)]

    monkeypatch.setattr(utils, "_convert_to_result", fake_convert_to_result)
65-
def test_get_request(monkeypatch):
    """
    Test that the get_request method constructs the request properly,
    calls the fake PredictionServiceClient, and returns the expected response data.
    """
    # Create a Triton model handler instance.
    handler = VertexAITritonModelHandler(
        project_id="test-project",
        region="us-central1",
        endpoint_name="12345",
        name="input_field",
        location="us-central1",
        datatype="BYTES",
    )

    # Create a fake endpoint to pass into get_request.
    fake_endpoint = FakeEndpoint()
    batch = ["test_input"]

    # Call get_request.
    # NOTE(review): the original spelled this keyword "inference_agrs"; the
    # handler's other call sites use "inference_args" — confirm against the
    # handler signature.
    response_data = handler.get_request(
        batch, fake_endpoint, throttle_delay_secs=1, inference_args={})

    # Validate the response structure.
    assert "outputs" in response_data
    outputs = response_data["outputs"]
    assert isinstance(outputs, list)
    assert "data" in outputs[0]
    # The fake client returns a list with two prediction results.
    assert outputs[0]["data"] == [["result1"], ["result2"]]
95-
96-
def test_run_inference(monkeypatch):
    """
    Test that run_inference converts the raw response into PredictionResult objects.
    """
    handler = VertexAITritonModelHandler(
        project_id="test-project",
        region="us-central1",
        endpoint_name="12345",
        name="input_field",
        location="us-central1",
        datatype="BYTES",
    )

    # load_model() is patched via _retrieve_endpoint so it returns our FakeEndpoint.
    model = handler.load_model()
    batch = ["input1", "input2"]

    # Call run_inference and collect the results.
    results = list(handler.run_inference(batch, model, inference_args={"payload_config": {}}))

    # With our fake response (2 predictions) and a batch of 2 inputs, the
    # patched conversion utility yields one
    # (input, prediction, deployed_model_id) tuple per input.
    assert len(results) == len(batch)
    for inp, pred, deployed_id in results:
        assert inp in batch
        assert deployed_id == model.deployed_model_id
56+ model = MagicMock (resource_name = "test-resource" )
57+ mock_client = MagicMock (spec = PredictionServiceClient )
58+ mock_response = MagicMock ()
59+ mock_response .data .decode .return_value = json .dumps ({"outputs" : [{"data" : [0.5 , 0.6 , 0.7 ]}]})
60+ mock_client .raw_predict .return_value = mock_response
61+
62+ with patch ('google.cloud.aiplatform.gapic.PredictionServiceClient' , return_value = mock_client ):
63+ self .handler .get_request (batch , model , throttle_delay_secs = 5 , inference_args = None )
64+ request = mock_client .raw_predict .call_args [1 ]["request" ]
65+ self .assertEqual (json .loads (request .http_body .data .decode ("utf-8" )), expected_payload )
66+
67+ def test_run_inference_parse_response (self ):
68+ """Test parsing of a Triton response into PredictionResult objects."""
69+ batch = [1.0 , 2.0 ]
70+ mock_response = {
71+ "outputs" : [
72+ {
73+ "name" : "output__0" ,
74+ "shape" : [2 , 1 ],
75+ "datatype" : "FP32" ,
76+ "data" : [0.5 , 0.6 ],
77+ }
78+ ]
79+ }
80+ model = MagicMock (resource_name = "test-resource" , deployed_model_id = "model-123" )
81+ with patch .object (self .handler , 'get_request' , return_value = mock_response ):
82+ results = self .handler .run_inference (batch , model )
83+ expected_results = [
84+ PredictionResult (example = 1.0 , inference = [0.5 ], model_id = "model-123" ),
85+ PredictionResult (example = 2.0 , inference = [0.6 ], model_id = "model-123" ),
86+ ]
87+ self .assertEqual (list (results ), expected_results )
88+
89+ def test_run_inference_empty_batch (self ):
90+ """Test that an empty batch returns an empty list."""
91+ batch = []
92+ model = MagicMock ()
93+ results = self .handler .run_inference (batch , model )
94+ self .assertEqual (list (results ), [])
95+
96+ def test_run_inference_malformed_response (self ):
97+ """Test that a malformed response raises an error."""
98+ batch = [1.0 ]
99+ mock_response = {"unexpected" : "data" }
100+ model = MagicMock ()
101+ with patch .object (self .handler , 'get_request' , return_value = mock_response ):
102+ with self .assertRaises (ValueError ) as cm :
103+ list (self .handler .run_inference (batch , model ))
104+ self .assertIn ("no outputs found" , str (cm .exception ))
105+
106+ def test_throttling_delays_request (self ):
107+ """Test that the handler delays requests when throttled."""
108+ batch = [1.0 ]
109+ model = MagicMock (resource_name = "test-resource" )
110+ self .handler .throttler .throttle_request = MagicMock (side_effect = [True , False ])
111+ mock_response = {"outputs" : [{"data" : [0.5 ]}]}
112+
113+ with patch ('time.sleep' ) as mock_sleep :
114+ with patch ('google.cloud.aiplatform.gapic.PredictionServiceClient' ) as mock_client :
115+ mock_client .return_value .raw_predict .return_value .data .decode .return_value = json .dumps (mock_response )
116+ self .handler .run_inference (batch , model )
117+ mock_sleep .assert_called_with (5 )
118+ self .handler .throttled_secs .inc .assert_called_with (5 )
119+
if __name__ == "__main__":
    unittest.main()