Add Triton Inference Server Support #34252

Draft · wants to merge 7 commits into master
Changes from 4 commits
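For orientation before the diff: a minimal sketch, not taken from this PR, of how the new handler could be plugged into a RunInference pipeline. It assumes the constructor shown in the diff below; the project, region, endpoint ID, and payload values are placeholders.

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAITritonModelHandler

# Placeholder values; the payload_config keys mirror the Triton/KServe v2
# "inputs" fields the handler fills in during run_inference.
model_handler = VertexAITritonModelHandler(
    project_id="my-project",
    region="us-central1",
    endpoint_name="1234567890",   # numeric Vertex AI endpoint ID
    location="us-central1",
    payload_config={
        "name": "input__0",       # model-specific input tensor name
        "shape": [1, 1],
        "datatype": "BYTES",
    })

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | "Create" >> beam.Create(["some text to classify"])
      | "Infer" >> RunInference(model_handler))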
110 changes: 110 additions & 0 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):

  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    return self._batching_kwargs


class VertexAITritonModelHandler(ModelHandler[Any,
                                              PredictionResult,
                                              aiplatform.Endpoint]):
  """
  A custom model handler for Vertex AI endpoints hosting Triton Inference
  Servers. It constructs a payload that Triton expects and calls the raw
  predict endpoint.
  """
  def __init__(
      self,
      project_id: str,
      region: str,
      endpoint_name: str,
      location: str,
      payload_config: Optional[Dict[str, Any]] = None,
      private: bool = False,
  ):
    self.project_id = project_id
    self.region = region
    self.endpoint_name = endpoint_name
    self.endpoint_url = (
        f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}"
        f"/locations/{region}/endpoints/{endpoint_name}:predict")
    self.is_private = private
Contributor: are there distinctions between public and private triton endpoints?
    self.location = location
    self.payload_config = payload_config if payload_config else {}

    # Configure AdaptiveThrottler and throttling metrics for client-side
    # throttling behavior.
    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
    # for more details.
    self.throttled_secs = Metrics.counter(
        VertexAITritonModelHandler, "cumulativeThrottlingSeconds")
    self.throttler = AdaptiveThrottler(
        window_ms=1, bucket_ms=1, overload_ratio=2)
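The throttler above is configured but not yet consulted in run_inference below; a minimal sketch, mirroring the AdaptiveThrottler pattern used by other Beam IOs, of how a request might be gated. The helper name and the 5-second retry interval are assumptions, not values from this PR.

import time

def _throttle_if_needed(throttler, throttled_secs, retry_interval_secs=5):
  # Sketch only: AdaptiveThrottler.throttle_request/successful_request take a
  # timestamp in milliseconds; block while the throttler asks us to back off.
  while throttler.throttle_request(time.time() * 1000):
    LOGGER.info(
        "Delaying request for %d seconds due to previous failures",
        retry_interval_secs)
    time.sleep(retry_interval_secs)
    throttled_secs.inc(retry_interval_secs)
  # After the Vertex AI call succeeds, the caller would record it so the
  # throttler can relax: throttler.successful_request(time.time() * 1000)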

  def load_model(self) -> aiplatform.Endpoint:
    """Loads the Endpoint object used to build and send prediction requests to
    Vertex AI.
    """
    # Check to make sure the endpoint is still active since pipeline
    # construction time.
    ep = self._retrieve_endpoint(
        self.endpoint_name, self.location, self.is_private)
    return ep

  def _retrieve_endpoint(
      self, endpoint_id: str,
      location: str,
      is_private: bool) -> aiplatform.Endpoint:
    """Retrieves an AI Platform endpoint and queries it for liveness/deployed
    models.

    Args:
      endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
      location: the region the endpoint is deployed in.
      is_private: a boolean indicating if the Vertex AI endpoint is a private
        endpoint.
    Returns:
      An aiplatform.Endpoint object.
    Raises:
      ValueError: if the endpoint is inactive or has no models deployed to it.
    """
    if is_private:
      endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
          endpoint_name=endpoint_id, location=location)
      LOGGER.debug("Treating endpoint %s as private", endpoint_id)
    else:
      endpoint = aiplatform.Endpoint(
          endpoint_name=endpoint_id, location=location)
      LOGGER.debug("Treating endpoint %s as public", endpoint_id)

    try:
      mod_list = endpoint.list_models()
    except Exception as e:
      raise ValueError(
          "Failed to contact endpoint %s, got exception: %s" % (endpoint_id, e))

    if len(mod_list) == 0:
      raise ValueError(
          "Endpoint %s has no models deployed to it." % endpoint_id)

    return endpoint
Contributor (on lines +308 to +342): Do triton endpoints function correctly in this way?


  def run_inference(
      self,
      batch: Sequence[Any],
      model: aiplatform.Endpoint,
Contributor: This does not align with usage, an endpoint object is not the model name.

Contributor (Author): @jrmccluskey Can you explain why the model parameter should not be an aiplatform.Endpoint? Since load_model returns an Endpoint object, it seems logical to use it for Vertex AI's raw_predict method (e.g., with Triton).

      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    """
    Sends a prediction request with the Triton-specific payload structure.
    """
    config = self.payload_config.copy()
    if inference_args:
      config.update(inference_args)

    payload = {
        "inputs": [
            {
                "name": config.get("name", "name"),
                "shape": config.get("shape", [1, 1]),
                "datatype": config.get("datatype", "BYTES"),
                "data": batch,
            }
        ]
    }
    client = aiplatform.gapic.PredictionServiceClient()
    predict_response = client.predict(model_name=model, instances=[payload])
    for inp, pred in zip(batch, predict_response.predictions):
      yield PredictionResult(inp, pred)
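Following up on the review thread above: the class docstring says the handler calls the raw predict endpoint, while the body uses aiplatform.gapic.PredictionServiceClient().predict. A minimal sketch, under the assumption that the aiplatform.Endpoint returned by load_model is used directly, of what a raw-predict call carrying the Triton/KServe v2 payload could look like; the helper name is hypothetical, the tensor name, shape, and datatype are model-specific placeholders, and this is not code from the PR.

import json

from google.cloud import aiplatform


def _triton_raw_predict(endpoint: aiplatform.Endpoint, batch, config):
  # Sketch only: build the KServe v2 inference request body that a Triton
  # server expects and forward it unchanged via raw_predict (predict() would
  # wrap instances in the Vertex AI JSON schema instead).
  body = json.dumps({
      "inputs": [{
          "name": config.get("name", "input__0"),           # placeholder
          "shape": config.get("shape", [len(batch), 1]),    # placeholder
          "datatype": config.get("datatype", "BYTES"),
          "data": list(batch),
      }]
  }).encode("utf-8")
  response = endpoint.raw_predict(
      body=body, headers={"Content-Type": "application/json"})
  # A Triton v2 response carries an "outputs" list; each entry has a "data"
  # field aligned with the request batch.
  return json.loads(response.text).get("outputs", [])

If this route lands, run_inference would zip the batch with the returned output data to build PredictionResults, much as it does with predict_response.predictions today.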