Add Triton Inference Server Support #34252
@@ -256,3 +256,113 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
  def batch_elements_kwargs(self) -> Mapping[str, Any]:
    return self._batching_kwargs

class VertexAITritonModelHandler(ModelHandler[Any,
                                              PredictionResult,
                                              aiplatform.Endpoint]):
  """A custom model handler for Vertex AI endpoints hosting Triton Inference
  Servers.

  It constructs a payload that Triton expects and calls the endpoint's raw
  predict route.
  """
  def __init__(
      self,
      project_id: str,
      region: str,
      endpoint_name: str,
      location: str,
      payload_config: Optional[Dict[str, Any]] = None,
      private: bool = False,
  ):
    self.project_id = project_id
    self.region = region
    self.endpoint_name = endpoint_name
    self.endpoint_url = (
        f"https://{region}-aiplatform.googleapis.com/v1/projects/"
        f"{project_id}/locations/{region}/endpoints/{endpoint_name}:predict")
    self.is_private = private
[Review comment] Are there distinctions between public and private Triton endpoints?
    self.location = location
    self.payload_config = payload_config if payload_config else {}

    # Configure AdaptiveThrottler and throttling metrics for client-side
    # throttling behavior.
    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
    # for more details.
    self.throttled_secs = Metrics.counter(
        VertexAITritonModelHandler, "cumulativeThrottlingSeconds")
    self.throttler = AdaptiveThrottler(
        window_ms=1, bucket_ms=1, overload_ratio=2)
  def load_model(self) -> aiplatform.Endpoint:
    """Loads the Endpoint object used to build and send prediction requests
    to Vertex AI.
    """
    # Check to make sure the endpoint is still active since pipeline
    # construction time.
    ep = self._retrieve_endpoint(
        self.endpoint_name, self.location, self.is_private)
    return ep
[Review comment] I cannot find any sort of discussion around public versus private Triton endpoints, but as I've said before, the aiplatform.Endpoint classes aren't what you should be using anyway.

  def _retrieve_endpoint(
      self,
      endpoint_id: str,
      location: str,
      is_private: bool) -> aiplatform.Endpoint:
"""Retrieves an AI Platform endpoint and queries it for liveness/deployed | ||
models. | ||
|
||
Args: | ||
endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve. | ||
is_private: a boolean indicating if the Vertex AI endpoint is a private | ||
endpoint | ||
Returns: | ||
An aiplatform.Endpoint object | ||
Raises: | ||
ValueError: if endpoint is inactive or has no models deployed to it. | ||
""" | ||
    if is_private:
      endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
          endpoint_name=endpoint_id, location=location)
      LOGGER.debug("Treating endpoint %s as private", endpoint_id)
    else:
      endpoint = aiplatform.Endpoint(
          endpoint_name=endpoint_id, location=location)
      LOGGER.debug("Treating endpoint %s as public", endpoint_id)

    try:
      mod_list = endpoint.list_models()
    except Exception as e:
      raise ValueError(
          "Failed to contact endpoint %s, got exception: %s" %
          (endpoint_id, e))

    if len(mod_list) == 0:
      raise ValueError(
          "Endpoint %s has no models deployed to it." % endpoint_id)

    return endpoint
[Review comment] Do Triton endpoints function correctly in this way?
[Review thread on the model parameter of run_inference]
Reviewer: This does not align with usage; an endpoint object is not the model name.
Author: @jrmccluskey Can you explain why the model parameter should not be an aiplatform.Endpoint? Since load_model returns an Endpoint object, it seems logical to use it for Vertex AI's raw_predict method (e.g., with Triton).
Reviewer: raw_predict isn't using an endpoint object, it uses a [comment truncated]
Reviewer: You're still deploying the model to a Vertex endpoint, but that object's abstraction in the SDK is not useful here.
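For illustration of this point, a minimal sketch of calling a Triton model on Vertex AI through the low-level gapic client rather than an aiplatform.Endpoint object; the project, region, endpoint ID, and tensor name below are placeholders:

    import json

    from google.api import httpbody_pb2
    from google.cloud import aiplatform

    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"})

    # Fully qualified endpoint resource name (a string, not an Endpoint object).
    endpoint_path = client.endpoint_path(
        project="my-project", location="us-central1", endpoint="1234567890")

    payload = {
        "inputs": [{
            "name": "INPUT_0",  # must match the Triton model config
            "shape": [1, 1],
            "datatype": "BYTES",
            "data": ["hello triton"],
        }]
    }

    response = client.raw_predict(
        endpoint=endpoint_path,
        http_body=httpbody_pb2.HttpBody(
            data=json.dumps(payload).encode("utf-8"),
            content_type="application/json"))

    # Triton's KFServing-v2 response carries results under "outputs".
    outputs = json.loads(response.data)["outputs"]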
  def run_inference(
      self,
      batch: Sequence[Any],
      model: aiplatform.Endpoint,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
""" | ||
Sends a prediction request with the Triton-specific payload structure. | ||
""" | ||
|
||
config = self.payload_config.copy() | ||
if inference_args: | ||
config.update(inference_args) | ||
SaumilPatel03 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
    payload = {
        "inputs": [{
            "name": config.get("name", "name"),
            "shape": config.get("shape", [1, 1]),
            "datatype": config.get("datatype", "BYTES"),
            "data": batch,
        }]
    }
    # Send the Triton payload over the endpoint's raw predict route
    # (requires "import json" at module level).
    response = model.raw_predict(
        body=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"})
    # Triton's KFServing-v2 responses carry results under "outputs".
    predictions = json.loads(response.text).get("outputs", [])
    for inp, pred in zip(batch, predictions):
      yield PredictionResult(inp, pred)
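For completeness, a minimal usage sketch of the handler under review; the project, region, endpoint ID, and payload fields are placeholders:

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    handler = VertexAITritonModelHandler(
        project_id="my-project",
        region="us-central1",
        endpoint_name="1234567890",
        location="us-central1",
        payload_config={
            "name": "INPUT_0",  # Triton input tensor name
            "shape": [1, 1],
            "datatype": "BYTES",
        })

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create(["example input"])
          | RunInference(handler)
          | beam.Map(print))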