diff --git a/sdks/python/apache_beam/ml/inference/test_triton_model_handler.py b/sdks/python/apache_beam/ml/inference/test_triton_model_handler.py
new file mode 100644
index 000000000000..86075e421fb0
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_triton_model_handler.py
@@ -0,0 +1,121 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.vertex_ai_inference import VertexAITritonModelHandler
+import json
+from google.cloud import aiplatform
+from google.cloud.aiplatform.gapic import PredictionServiceClient
+
+class TestVertexAITritonModelHandler(unittest.TestCase):
+  def setUp(self):
+    """Initialize the handler with test parameters before each test."""
+    self.handler = VertexAITritonModelHandler(
+        project_id="test-project",
+        region="us-central1",
+        endpoint_name="test-endpoint",
+        name="input__0",
+        location="us-central1",
+        datatype="FP32",
+        private=False,
+    )
+    self.handler.throttler = MagicMock()
+    self.handler.throttler.throttle_request = MagicMock(return_value=False)
+    self.handler.throttled_secs = MagicMock()
+
+  def test_load_model_public_endpoint(self):
+    """Test loading a public endpoint and verifying deployed models."""
+    mock_endpoint = MagicMock(spec=aiplatform.Endpoint)
+    mock_endpoint.list_models.return_value = [MagicMock()]
+    with patch('google.cloud.aiplatform.Endpoint', return_value=mock_endpoint):
+      model = self.handler.load_model()
+      self.assertEqual(model, mock_endpoint)
+      mock_endpoint.list_models.assert_called_once()
+
+  def test_load_model_no_deployed_models(self):
+    """Test that an endpoint with no deployed models raises ValueError."""
+    mock_endpoint = MagicMock(spec=aiplatform.Endpoint)
+    mock_endpoint.list_models.return_value = []
+    with patch('google.cloud.aiplatform.Endpoint', return_value=mock_endpoint):
+      with self.assertRaises(ValueError) as cm:
+        self.handler.load_model()
+      self.assertIn("no models deployed", str(cm.exception))
+
+  def test_get_request_payload_scalar(self):
+    """Test payload construction for a batch of scalar inputs."""
+    batch = [1.0, 2.0, 3.0]
+    expected_payload = {
+        "inputs": [{
+            "name": "input__0",
+            "shape": [3, 1],
+            "datatype": "FP32",
+            "data": [1.0, 2.0, 3.0],
+        }]
+    }
+    model = MagicMock(resource_name="test-resource")
+    mock_client = MagicMock(spec=PredictionServiceClient)
+    mock_response = MagicMock()
+    mock_response.data.decode.return_value = json.dumps(
+        {"outputs": [{"data": [0.5, 0.6, 0.7]}]})
+    mock_client.raw_predict.return_value = mock_response
+
+    with patch(
+        'google.cloud.aiplatform.gapic.PredictionServiceClient',
+        return_value=mock_client):
+      self.handler.get_request(
+          batch, model, throttle_delay_secs=5, inference_args=None)
+      request = mock_client.raw_predict.call_args[1]["request"]
+      self.assertEqual(
+          json.loads(request.http_body.data.decode("utf-8")), expected_payload)
+
+  def test_run_inference_parse_response(self):
+    """Test parsing of a Triton response into PredictionResult objects."""
+    batch = [1.0, 2.0]
+    mock_response = {
+        "outputs": [{
+            "name": "output__0",
+            "shape": [2, 1],
+            "datatype": "FP32",
+            "data": [0.5, 0.6],
+        }]
+    }
+    model = MagicMock(
+        resource_name="test-resource", deployed_model_id="model-123")
+    with patch.object(self.handler, 'get_request', return_value=mock_response):
+      results = self.handler.run_inference(batch, model)
+      expected_results = [
+          PredictionResult(
+              example=1.0, inference=[0.5], model_id="model-123"),
+          PredictionResult(example=2.0, inference=[0.6],
+                           model_id="model-123"),
+      ]
+      self.assertEqual(list(results), expected_results)
+
+  def test_run_inference_empty_batch(self):
+    """Test that an empty batch returns an empty list."""
+    batch = []
+    model = MagicMock()
+    results = self.handler.run_inference(batch, model)
+    self.assertEqual(list(results), [])
+
+  def test_run_inference_malformed_response(self):
+    """Test that a malformed response raises an error."""
+    batch = [1.0]
+    mock_response = {"unexpected": "data"}
+    model = MagicMock()
+    with patch.object(self.handler, 'get_request', return_value=mock_response):
+      with self.assertRaises(ValueError) as cm:
+        list(self.handler.run_inference(batch, model))
+      self.assertIn("no outputs found", str(cm.exception))
+
+  def test_throttling_delays_request(self):
+    """Test that the handler delays requests when throttled."""
+    batch = [1.0]
+    model = MagicMock(resource_name="test-resource")
+    self.handler.throttler.throttle_request = MagicMock(
+        side_effect=[True, False])
+    mock_response = {"outputs": [{"data": [0.5]}]}
+
+    with patch('time.sleep') as mock_sleep:
+      with patch(
+          'google.cloud.aiplatform.gapic.PredictionServiceClient'
+      ) as mock_client:
+        mock_client.return_value.raw_predict.return_value.data.decode.return_value = json.dumps(
+            mock_response)
+        self.handler.run_inference(batch, model)
+        mock_sleep.assert_called_with(5)
+        self.handler.throttled_secs.inc.assert_called_with(5)
+
+if __name__ == "__main__":
+  unittest.main()
\ No newline at end of file
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 4c4163accfb9..852e5b462d4e 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -17,13 +17,14 @@
 
 import logging
 import time
+import json
 from typing import Any
 from typing import Dict
 from typing import Iterable
 from typing import Mapping
 from typing import Optional
 from typing import Sequence
-
+from google.api import httpbody_pb2
 from google.api_core.exceptions import ServerError
 from google.api_core.exceptions import TooManyRequests
 from google.cloud import aiplatform
@@ -256,3 +257,152 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
+
+
+class VertexAITritonModelHandler(ModelHandler[Any,
+                                              PredictionResult,
+                                              aiplatform.Endpoint]):
+  """
+  A custom model handler for Vertex AI endpoints hosting Triton Inference
+  Servers. It constructs a payload that Triton expects and calls the raw
+  predict endpoint.
+  """
+
+  def __init__(
+      self,
+      project_id: str,
+      region: str,
+      endpoint_name: str,
+      name: str,
+      location: str,
+      datatype: str,
+      payload_config: Optional[Dict[str, Any]] = None,
+      private: bool = False,
+  ):
+    self.project_id = project_id
+    self.region = region
+    self.endpoint_name = endpoint_name
+    self.input_name = name
+    self.endpoint_url = (
+        f"https://{region}-aiplatform.googleapis.com/v1/projects/"
+        f"{project_id}/locations/{region}/endpoints/{endpoint_name}:predict")
+    self.is_private = private
+    self.location = location
+    self.datatype = datatype
+    self.payload_config = payload_config if payload_config else {}
+    # Configure AdaptiveThrottler and throttling metrics for client-side
+    # throttling behavior.
+    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
+    # for more details.
+    self.throttled_secs = Metrics.counter(
+        VertexAITritonModelHandler, "cumulativeThrottlingSeconds")
+    self.throttler = AdaptiveThrottler(
+        window_ms=1, bucket_ms=1, overload_ratio=2)
+
+  def load_model(self) -> aiplatform.Endpoint:
+    """Loads the Endpoint object used to build and send prediction requests
+    to Vertex AI.
+    """
+    # Check to make sure the endpoint is still active since pipeline
+    # construction time.
+    ep = self._retrieve_endpoint(
+        self.endpoint_name, self.location, self.is_private)
+    return ep
+
+  def _retrieve_endpoint(
+      self, endpoint_id: str, location: str,
+      is_private: bool) -> aiplatform.Endpoint:
+    """Retrieves an AI Platform endpoint and queries it for liveness/deployed
+    models.
+
+    Args:
+      endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
+      location: the GCP location of the Vertex AI endpoint.
+      is_private: a boolean indicating if the Vertex AI endpoint is a private
+        endpoint.
+    Returns:
+      An aiplatform.Endpoint object.
+    Raises:
+      ValueError: if endpoint is inactive or has no models deployed to it.
+    """
+    if is_private:
+      endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
+          endpoint_name=endpoint_id, location=location)
+      LOGGER.debug("Treating endpoint %s as private", endpoint_id)
+    else:
+      endpoint = aiplatform.Endpoint(
+          endpoint_name=endpoint_id, location=location)
+      LOGGER.debug("Treating endpoint %s as public", endpoint_id)
+
+    try:
+      mod_list = endpoint.list_models()
+    except Exception as e:
+      raise ValueError(
+          f"Failed to contact endpoint {endpoint_id}, got exception: {e}")
+
+    if len(mod_list) == 0:
+      raise ValueError(
+          f"Endpoint {endpoint_id} has no models deployed to it.")
+
+    return endpoint
+
+  def get_request(
+      self,
+      batch: Sequence[Any],
+      model: aiplatform.Endpoint,
+      throttle_delay_secs: int,
+      inference_args: Optional[Dict[str, Any]]):
+    while self.throttler.throttle_request(time.time() * MSEC_TO_SEC):
+      LOGGER.info(
+          "Delaying request for %d seconds due to previous failures",
+          throttle_delay_secs)
+      time.sleep(throttle_delay_secs)
+      self.throttled_secs.inc(throttle_delay_secs)
+
+    triton_request = {
+        "inputs": [{
+            "name": self.input_name,
+            "shape": [len(batch), 1],
+            "datatype": self.datatype,
+            "data": batch
+        }]
+    }
+    body = json.dumps(triton_request).encode("utf-8")
+    api_endpoint = f"{self.region}-aiplatform.googleapis.com"
+    client_options = {"api_endpoint": api_endpoint}
+    pred_client = aiplatform.gapic.PredictionServiceClient(
+        client_options=client_options)
+    request = aiplatform.gapic.RawPredictRequest(
+        endpoint=model.resource_name,
+        http_body=httpbody_pb2.HttpBody(
+            data=body, content_type="application/json"))
+    response = pred_client.raw_predict(request=request)
+    response_data = json.loads(response.data.decode('utf-8'))
+    return response_data
+
+  def run_inference(
+      self,
+      batch: Sequence[Any],
+      model: aiplatform.Endpoint,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """
+    Sends a prediction request with the Triton-specific payload structure.
+
+    Args:
+      batch: a sequence of any values to be passed to the Vertex AI endpoint.
+        Should be encoded as the model expects.
+      model: an aiplatform.Endpoint object configured to access the desired
+        model.
+      inference_args: any additional arguments to send as part of the
+        prediction request.
+
+    Returns:
+      An iterable of PredictionResult objects.
+ """ + if not batch: + return [] + prediction = self.get_request( + batch,model,throttle_delay_secs=5,inference_args=inference_args + ) + if "outputs" not in prediction or not prediction["outputs"]: + raise ValueError("Unexpected response format from Triton server: no outputs found.") + output_data = prediction["outputs"][0]["data"] + predictions = [output_data[i:i+1] for i in range(0, len(output_data), 1)] + + return utils._convert_to_result( + batch, predictions, model.deployed_model_id)