Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/hopsworks_common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ class PREDICTOR:
# serving tool
SERVING_TOOL_DEFAULT = "DEFAULT"
SERVING_TOOL_KSERVE = "KSERVE"
# vLLM variant
VLLM_VARIANT_VLLM = "VLLM"
VLLM_VARIANT_OMNI = "VLLM_OMNI"


class PREDICTOR_STATE:
Expand Down
2 changes: 1 addition & 1 deletion python/hsml/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from hopsworks_apigen import public
from hopsworks_common import client, usage, util
from hopsworks_common.client.exceptions import ModelServingException
from hsml import predictor as predictor_mod
from hsml.client.exceptions import ModelServingException
from hsml.constants import DEPLOYABLE_COMPONENT, PREDICTOR_STATE
from hsml.core import model_api, serving_api
from hsml.engine import serving_engine
Expand Down
3 changes: 2 additions & 1 deletion python/hsml/engine/local_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import os

from hopsworks_common.core import dataset_api
from hsml import client
from hsml.core import dataset_api, hdfs_api, model_api
from hsml.core import hdfs_api, model_api


class LocalEngine:
Expand Down
20 changes: 19 additions & 1 deletion python/hsml/llm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,33 @@
# limitations under the License.
#

from __future__ import annotations

from hsml.constants import MODEL, PREDICTOR
from hsml.predictor import Predictor


class Predictor(Predictor):
"""Configuration for a predictor running with the vLLM backend."""

def __init__(self, **kwargs):
def __init__(
self,
vllm_variant: str | None = None,
vllm_image_tag: str | None = None,
**kwargs,
):
if vllm_variant is None:
vllm_variant = PREDICTOR.VLLM_VARIANT_VLLM
else:
valid_variants = {PREDICTOR.VLLM_VARIANT_VLLM, PREDICTOR.VLLM_VARIANT_OMNI}
if vllm_variant not in valid_variants:
raise ValueError(
f"vLLM variant '{vllm_variant}' is not valid. Possible values are {sorted(valid_variants)}"
)

kwargs["model_framework"] = MODEL.FRAMEWORK_LLM
kwargs["model_server"] = PREDICTOR.MODEL_SERVER_VLLM
kwargs["vllm_variant"] = vllm_variant
kwargs["vllm_image_tag"] = vllm_image_tag

super().__init__(**kwargs)
6 changes: 6 additions & 0 deletions python/hsml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def deploy(
api_protocol: str | None = IE.API_PROTOCOL_REST,
environment: str | None = None,
env_vars: dict | None = None,
vllm_variant: str | None = None,
vllm_image_tag: str | None = None,
) -> deployment.Deployment:
"""Deploy the model.

Expand Down Expand Up @@ -261,6 +263,8 @@ def deploy(
api_protocol: API protocol to be enabled in the deployment (i.e., 'REST' or 'GRPC').
environment: The inference environment to use.
env_vars: Environment variables to set on the predictor.
vllm_variant: vLLM image variant for vLLM deployments. One of `'VLLM'` or `'VLLM_OMNI'`. Ignored for non-vLLM model servers.
vllm_image_tag: vLLM image tag override. `None` uses the cluster default; if set, it should match one of the tags made available by a cluster administrator. Ignored for non-vLLM model servers.

Returns:
The deployment metadata object of a new or existing deployment.
Expand All @@ -286,6 +290,8 @@ def deploy(
api_protocol=api_protocol,
environment=environment,
env_vars=env_vars,
vllm_variant=vllm_variant,
vllm_image_tag=vllm_image_tag,
)

return predictor.deploy()
Expand Down
6 changes: 6 additions & 0 deletions python/hsml/model_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def create_predictor(
environment: str | None = None,
scaling_configuration: PredictorScalingConfig | dict | None = None,
env_vars: dict | None = None,
vllm_variant: str | None = None,
vllm_image_tag: str | None = None,
) -> Predictor:
"""Create a Predictor metadata object.

Expand Down Expand Up @@ -230,6 +232,8 @@ def create_predictor(
environment: The project Python environment to use
scaling_configuration: Scaling configuration for the predictor.
env_vars: Environment variables to set on the predictor.
vllm_variant: vLLM image variant for vLLM deployments. One of `'VLLM'` or `'VLLM_OMNI'`. Ignored for non-vLLM model servers.
vllm_image_tag: vLLM image tag override. `None` uses the cluster default; if set, it should match one of the tags made available by a cluster administrator. Ignored for non-vLLM model servers.

Returns:
The predictor metadata object.
Expand All @@ -251,6 +255,8 @@ def create_predictor(
environment=environment,
scaling_configuration=scaling_configuration,
env_vars=env_vars,
vllm_variant=vllm_variant,
vllm_image_tag=vllm_image_tag,
)

@public
Expand Down
36 changes: 36 additions & 0 deletions python/hsml/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def __init__(
project_namespace: str = None,
scaling_configuration: PredictorScalingConfig | dict | Default | None = None,
env_vars: dict[str, str] | None = None,
vllm_variant: str | None = None,
vllm_image_tag: str | None = None,
**kwargs,
):
serving_tool = (
Expand Down Expand Up @@ -126,6 +128,8 @@ def __init__(
self._project_namespace = project_namespace
self._project_name = None
self._env_vars = env_vars
self._vllm_variant = vllm_variant
self._vllm_image_tag = vllm_image_tag

@public
def deploy(self) -> deployment.Deployment:
Expand Down Expand Up @@ -331,6 +335,12 @@ def extract_fields_from_json(cls, json_decamelized):
kwargs["scaling_configuration"] = PredictorScalingConfig.from_json(
json_decamelized
)
kwargs["vllm_variant"] = util.extract_field_from_json(
json_decamelized, "vllm_variant"
)
kwargs["vllm_image_tag"] = util.extract_field_from_json(
json_decamelized, "vllm_image_tag"
)
return kwargs

def update_from_response_json(self, json_dict):
Expand All @@ -357,6 +367,12 @@ def to_dict(self):
"apiProtocol": self._api_protocol,
"projectNamespace": self._project_namespace,
}
if self._model_server == PREDICTOR.MODEL_SERVER_VLLM:
json = {
**json,
"vllmVariant": self._vllm_variant,
"vllmImageTag": self._vllm_image_tag,
}
if self.model_name is not None:
json = {**json, "modelName": self._model_name}
if self.model_path is not None:
Expand Down Expand Up @@ -623,6 +639,26 @@ def project_name(self):
def project_name(self, project_name: str):
self._project_name = project_name

@public
@property
def vllm_variant(self):
"""VLLM image variant for this predictor (VLLM or VLLM_OMNI)."""
return self._vllm_variant

@vllm_variant.setter
def vllm_variant(self, vllm_variant: str):
self._vllm_variant = vllm_variant

@public
@property
def vllm_image_tag(self):
"""VLLM image tag override; None means use the cluster default."""
return self._vllm_image_tag

@vllm_image_tag.setter
def vllm_image_tag(self, vllm_image_tag: str):
self._vllm_image_tag = vllm_image_tag

@public
def get_endpoint_url(self) -> str | None:
"""Get the base endpoint URL for this predictor.
Expand Down
80 changes: 80 additions & 0 deletions python/tests/fixtures/predictor_fixtures.json
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,86 @@
}
}
},
"get_deployment_vllm_kserve_vllm_variant": {
"response": {
"id": 10,
"name": "my_llm",
"description": "vllm deployment",
"version": 1,
"created": "",
"creator": "",
"model_path": "llm_model_path",
"model_name": "my_llm",
"model_version": 1,
"model_framework": "LLM",
"model_server": "VLLM",
"serving_tool": "KSERVE",
"api_protocol": "REST",
"config_file": "config.yaml",
"vllmVariant": "VLLM",
"vllmImageTag": null,
"requested_instances": 0,
"predictor_resources": {
"requested_instances": 0,
"requests": { "cores": 1.0, "memory": 4096, "gpus": 1 },
"limits": { "cores": 2.0, "memory": 8192, "gpus": 1 }
},
"environment_dto": {
"name": "llm-inference-pipeline"
},
"project_namespace": "test",
"predictor_scaling_config": {
"scale_metric": "RPS",
"target": 10,
"min_instances": 0,
"max_instances": 2,
"panic_window_percentage": 10.0,
"stable_window_seconds": 60,
"panic_threshold_percentage": 200.0,
"scale_to_zero_retention_seconds": 0
}
}
},
"get_deployment_vllm_kserve_omni_variant": {
"response": {
"id": 11,
"name": "my_llm_omni",
"description": "vllm omni deployment",
"version": 1,
"created": "",
"creator": "",
"model_path": "llm_model_path",
"model_name": "my_llm_omni",
"model_version": 1,
"model_framework": "LLM",
"model_server": "VLLM",
"serving_tool": "KSERVE",
"api_protocol": "REST",
"config_file": "config.yaml",
"vllmVariant": "VLLM_OMNI",
"vllmImageTag": "v0.14.0",
"requested_instances": 0,
"predictor_resources": {
"requested_instances": 0,
"requests": { "cores": 1.0, "memory": 4096, "gpus": 1 },
"limits": { "cores": 2.0, "memory": 8192, "gpus": 1 }
},
"environment_dto": {
"name": "llm-inference-pipeline"
},
"project_namespace": "test",
"predictor_scaling_config": {
"scale_metric": "RPS",
"target": 10,
"min_instances": 0,
"max_instances": 2,
"panic_window_percentage": 10.0,
"stable_window_seconds": 60,
"panic_threshold_percentage": 200.0,
"scale_to_zero_retention_seconds": 0
}
}
},
"get_deployment_predictor_state": {
"response": {
"available_instances": 1,
Expand Down
2 changes: 2 additions & 0 deletions python/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def test_deploy(self, mocker, backend_fixtures):
api_protocol=p_json["api_protocol"],
environment=p_json["environment_dto"]["name"],
env_vars=None,
vllm_variant=None,
vllm_image_tag=None,
)
mock_predictor.deploy.assert_called_once()

Expand Down
Loading
Loading