diff --git a/python/hopsworks_common/constants.py b/python/hopsworks_common/constants.py index 6dbf7a2cc3..48b2f76ce2 100644 --- a/python/hopsworks_common/constants.py +++ b/python/hopsworks_common/constants.py @@ -205,6 +205,9 @@ class PREDICTOR: # serving tool SERVING_TOOL_DEFAULT = "DEFAULT" SERVING_TOOL_KSERVE = "KSERVE" + # vLLM variant + VLLM_VARIANT_VLLM = "VLLM" + VLLM_VARIANT_OMNI = "VLLM_OMNI" class PREDICTOR_STATE: diff --git a/python/hsml/deployment.py b/python/hsml/deployment.py index 2619d0a9d3..cb8c5dd070 100644 --- a/python/hsml/deployment.py +++ b/python/hsml/deployment.py @@ -18,8 +18,8 @@ from hopsworks_apigen import public from hopsworks_common import client, usage, util +from hopsworks_common.client.exceptions import ModelServingException from hsml import predictor as predictor_mod -from hsml.client.exceptions import ModelServingException from hsml.constants import DEPLOYABLE_COMPONENT, PREDICTOR_STATE from hsml.core import model_api, serving_api from hsml.engine import serving_engine diff --git a/python/hsml/engine/local_engine.py b/python/hsml/engine/local_engine.py index 034bf5c69c..273782d476 100644 --- a/python/hsml/engine/local_engine.py +++ b/python/hsml/engine/local_engine.py @@ -16,8 +16,9 @@ import os +from hopsworks_common.core import dataset_api from hsml import client -from hsml.core import dataset_api, hdfs_api, model_api +from hsml.core import hdfs_api, model_api class LocalEngine: diff --git a/python/hsml/llm/predictor.py b/python/hsml/llm/predictor.py index 065c2b1a67..ba5b440a81 100644 --- a/python/hsml/llm/predictor.py +++ b/python/hsml/llm/predictor.py @@ -14,6 +14,8 @@ # limitations under the License. 
# +from __future__ import annotations + from hsml.constants import MODEL, PREDICTOR from hsml.predictor import Predictor @@ -21,8 +23,24 @@ class Predictor(Predictor): """Configuration for a predictor running with the vLLM backend.""" - def __init__(self, **kwargs): + def __init__( + self, + vllm_variant: str | None = None, + vllm_image_tag: str | None = None, + **kwargs, + ): + if vllm_variant is None: + vllm_variant = PREDICTOR.VLLM_VARIANT_VLLM + else: + valid_variants = {PREDICTOR.VLLM_VARIANT_VLLM, PREDICTOR.VLLM_VARIANT_OMNI} + if vllm_variant not in valid_variants: + raise ValueError( + f"vLLM variant '{vllm_variant}' is not valid. Possible values are {sorted(valid_variants)}" + ) + kwargs["model_framework"] = MODEL.FRAMEWORK_LLM kwargs["model_server"] = PREDICTOR.MODEL_SERVER_VLLM + kwargs["vllm_variant"] = vllm_variant + kwargs["vllm_image_tag"] = vllm_image_tag super().__init__(**kwargs) diff --git a/python/hsml/model.py b/python/hsml/model.py index 523b53e61d..bc78c8142f 100644 --- a/python/hsml/model.py +++ b/python/hsml/model.py @@ -225,6 +225,8 @@ def deploy( api_protocol: str | None = IE.API_PROTOCOL_REST, environment: str | None = None, env_vars: dict | None = None, + vllm_variant: str | None = None, + vllm_image_tag: str | None = None, ) -> deployment.Deployment: """Deploy the model. @@ -261,6 +263,8 @@ def deploy( api_protocol: API protocol to be enabled in the deployment (i.e., 'REST' or 'GRPC'). environment: The inference environment to use. env_vars: Environment variables to set on the predictor. + vllm_variant: vLLM image variant for vLLM deployments. One of `'VLLM'` or `'VLLM_OMNI'`. Ignored for non-vLLM model servers. + vllm_image_tag: vLLM image tag override. `None` uses the cluster default; if set, it should match one of the tags made available by a cluster administrator. Ignored for non-vLLM model servers. Returns: The deployment metadata object of a new or existing deployment. 
@@ -286,6 +290,8 @@ def deploy( api_protocol=api_protocol, environment=environment, env_vars=env_vars, + vllm_variant=vllm_variant, + vllm_image_tag=vllm_image_tag, ) return predictor.deploy() diff --git a/python/hsml/model_serving.py b/python/hsml/model_serving.py index ff3fa4d806..eed9144a49 100644 --- a/python/hsml/model_serving.py +++ b/python/hsml/model_serving.py @@ -188,6 +188,8 @@ def create_predictor( environment: str | None = None, scaling_configuration: PredictorScalingConfig | dict | None = None, env_vars: dict | None = None, + vllm_variant: str | None = None, + vllm_image_tag: str | None = None, ) -> Predictor: """Create a Predictor metadata object. @@ -230,6 +232,8 @@ def create_predictor( environment: The project Python environment to use scaling_configuration: Scaling configuration for the predictor. env_vars: Environment variables to set on the predictor. + vllm_variant: vLLM image variant for vLLM deployments. One of `'VLLM'` or `'VLLM_OMNI'`. Ignored for non-vLLM model servers. + vllm_image_tag: vLLM image tag override. `None` uses the cluster default; if set, it should match one of the tags made available by a cluster administrator. Ignored for non-vLLM model servers. Returns: The predictor metadata object. 
@@ -251,6 +255,8 @@ def create_predictor( environment=environment, scaling_configuration=scaling_configuration, env_vars=env_vars, + vllm_variant=vllm_variant, + vllm_image_tag=vllm_image_tag, ) @public diff --git a/python/hsml/predictor.py b/python/hsml/predictor.py index 5ead96cb0d..6b77e24681 100644 --- a/python/hsml/predictor.py +++ b/python/hsml/predictor.py @@ -78,6 +78,8 @@ def __init__( project_namespace: str = None, scaling_configuration: PredictorScalingConfig | dict | Default | None = None, env_vars: dict[str, str] | None = None, + vllm_variant: str | None = None, + vllm_image_tag: str | None = None, **kwargs, ): serving_tool = ( @@ -126,6 +128,8 @@ def __init__( self._project_namespace = project_namespace self._project_name = None self._env_vars = env_vars + self._vllm_variant = vllm_variant + self._vllm_image_tag = vllm_image_tag @public def deploy(self) -> deployment.Deployment: @@ -331,6 +335,12 @@ def extract_fields_from_json(cls, json_decamelized): kwargs["scaling_configuration"] = PredictorScalingConfig.from_json( json_decamelized ) + kwargs["vllm_variant"] = util.extract_field_from_json( + json_decamelized, "vllm_variant" + ) + kwargs["vllm_image_tag"] = util.extract_field_from_json( + json_decamelized, "vllm_image_tag" + ) return kwargs def update_from_response_json(self, json_dict): @@ -357,6 +367,12 @@ def to_dict(self): "apiProtocol": self._api_protocol, "projectNamespace": self._project_namespace, } + if self._model_server == PREDICTOR.MODEL_SERVER_VLLM: + json = { + **json, + "vllmVariant": self._vllm_variant, + "vllmImageTag": self._vllm_image_tag, + } if self.model_name is not None: json = {**json, "modelName": self._model_name} if self.model_path is not None: @@ -623,6 +639,26 @@ def project_name(self): def project_name(self, project_name: str): self._project_name = project_name + @public + @property + def vllm_variant(self): + """VLLM image variant for this predictor (VLLM or VLLM_OMNI).""" + return self._vllm_variant + + 
@vllm_variant.setter + def vllm_variant(self, vllm_variant: str): + self._vllm_variant = vllm_variant + + @public + @property + def vllm_image_tag(self): + """VLLM image tag override; None means use the cluster default.""" + return self._vllm_image_tag + + @vllm_image_tag.setter + def vllm_image_tag(self, vllm_image_tag: str): + self._vllm_image_tag = vllm_image_tag + @public def get_endpoint_url(self) -> str | None: """Get the base endpoint URL for this predictor. diff --git a/python/tests/fixtures/predictor_fixtures.json b/python/tests/fixtures/predictor_fixtures.json index 542b2094f2..c7837b0ebb 100644 --- a/python/tests/fixtures/predictor_fixtures.json +++ b/python/tests/fixtures/predictor_fixtures.json @@ -542,6 +542,86 @@ } } }, + "get_deployment_vllm_kserve_vllm_variant": { + "response": { + "id": 10, + "name": "my_llm", + "description": "vllm deployment", + "version": 1, + "created": "", + "creator": "", + "model_path": "llm_model_path", + "model_name": "my_llm", + "model_version": 1, + "model_framework": "LLM", + "model_server": "VLLM", + "serving_tool": "KSERVE", + "api_protocol": "REST", + "config_file": "config.yaml", + "vllmVariant": "VLLM", + "vllmImageTag": null, + "requested_instances": 0, + "predictor_resources": { + "requested_instances": 0, + "requests": { "cores": 1.0, "memory": 4096, "gpus": 1 }, + "limits": { "cores": 2.0, "memory": 8192, "gpus": 1 } + }, + "environment_dto": { + "name": "llm-inference-pipeline" + }, + "project_namespace": "test", + "predictor_scaling_config": { + "scale_metric": "RPS", + "target": 10, + "min_instances": 0, + "max_instances": 2, + "panic_window_percentage": 10.0, + "stable_window_seconds": 60, + "panic_threshold_percentage": 200.0, + "scale_to_zero_retention_seconds": 0 + } + } + }, + "get_deployment_vllm_kserve_omni_variant": { + "response": { + "id": 11, + "name": "my_llm_omni", + "description": "vllm omni deployment", + "version": 1, + "created": "", + "creator": "", + "model_path": "llm_model_path", + 
"model_name": "my_llm_omni", + "model_version": 1, + "model_framework": "LLM", + "model_server": "VLLM", + "serving_tool": "KSERVE", + "api_protocol": "REST", + "config_file": "config.yaml", + "vllmVariant": "VLLM_OMNI", + "vllmImageTag": "v0.14.0", + "requested_instances": 0, + "predictor_resources": { + "requested_instances": 0, + "requests": { "cores": 1.0, "memory": 4096, "gpus": 1 }, + "limits": { "cores": 2.0, "memory": 8192, "gpus": 1 } + }, + "environment_dto": { + "name": "llm-inference-pipeline" + }, + "project_namespace": "test", + "predictor_scaling_config": { + "scale_metric": "RPS", + "target": 10, + "min_instances": 0, + "max_instances": 2, + "panic_window_percentage": 10.0, + "stable_window_seconds": 60, + "panic_threshold_percentage": 200.0, + "scale_to_zero_retention_seconds": 0 + } + } + }, "get_deployment_predictor_state": { "response": { "available_instances": 1, diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 92cdab45e1..4d4a06f0ce 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -240,6 +240,8 @@ def test_deploy(self, mocker, backend_fixtures): api_protocol=p_json["api_protocol"], environment=p_json["environment_dto"]["name"], env_vars=None, + vllm_variant=None, + vllm_image_tag=None, ) mock_predictor.deploy.assert_called_once() diff --git a/python/tests/test_predictor.py b/python/tests/test_predictor.py index 0799a85bdd..de00ab48d1 100644 --- a/python/tests/test_predictor.py +++ b/python/tests/test_predictor.py @@ -1175,6 +1175,179 @@ def test_env_vars_wire_round_trip(self, mocker, backend_fixtures): assert kwargs["env_vars"] == {"FOO": "bar", "K": "V=with=eq"} + # vLLM variant round-trip + + def test_vllm_variant_vllm_round_trip(self, mocker, backend_fixtures): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + p_json = backend_fixtures["predictor"][ + "get_deployment_vllm_kserve_vllm_variant" + ]["response"] + + # Act + p = 
predictor.Predictor.from_response_json(p_json) + serialized = p.to_dict() + p2 = predictor.Predictor.from_response_json(serialized) + + # Assert + assert p.vllm_variant == "VLLM" + assert p.vllm_image_tag is None + assert p2.vllm_variant == p.vllm_variant + assert p2.vllm_image_tag == p.vllm_image_tag + assert serialized["vllmVariant"] == "VLLM" + assert serialized["vllmImageTag"] is None + + def test_vllm_variant_omni_round_trip(self, mocker, backend_fixtures): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + p_json = backend_fixtures["predictor"][ + "get_deployment_vllm_kserve_omni_variant" + ]["response"] + + # Act + p = predictor.Predictor.from_response_json(p_json) + serialized = p.to_dict() + p2 = predictor.Predictor.from_response_json(serialized) + + # Assert + assert p.vllm_variant == "VLLM_OMNI" + assert p.vllm_image_tag == "v0.14.0" + assert p2.vllm_variant == p.vllm_variant + assert p2.vllm_image_tag == p.vllm_image_tag + assert serialized["vllmVariant"] == "VLLM_OMNI" + assert serialized["vllmImageTag"] == "v0.14.0" + + def test_llm_predictor_default_variant(self, mocker): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + from hsml.llm.predictor import Predictor as LLMPredictor + + # Act + p = LLMPredictor(name="my_llm") + + # Assert + assert p.vllm_variant == PREDICTOR.VLLM_VARIANT_VLLM + assert p.vllm_image_tag is None + + def test_llm_predictor_omni_variant(self, mocker): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + from hsml.llm.predictor import Predictor as LLMPredictor + + # Act + p = LLMPredictor( + name="my_llm", + vllm_variant=PREDICTOR.VLLM_VARIANT_OMNI, + vllm_image_tag="v0.14.0", + ) + + # Assert + assert p.vllm_variant == PREDICTOR.VLLM_VARIANT_OMNI + assert p.vllm_image_tag == "v0.14.0" + + def test_llm_predictor_invalid_variant_raises(self, mocker): + # Arrange + self._mock_serving_variables(mocker, 
SERVING_NUM_INSTANCES_NO_LIMIT) + from hsml.llm.predictor import Predictor as LLMPredictor + + # Act + Assert + with pytest.raises(ValueError) as e_info: + LLMPredictor(name="my_llm", vllm_variant="INVALID") + + assert "is not valid" in str(e_info.value) + + def test_llm_predictor_default_to_dict_wire_format(self, mocker): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + from hsml.llm.predictor import Predictor as LLMPredictor + + p = LLMPredictor(name="my_llm") + + # Act + serialized = p.to_dict() + + # Assert + assert serialized["vllmVariant"] == PREDICTOR.VLLM_VARIANT_VLLM + assert serialized["vllmImageTag"] is None + + def test_non_vllm_to_dict_does_not_emit_vllm_keys(self, mocker): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + + p = predictor.Predictor( + name="my_model", + model_server=PREDICTOR.MODEL_SERVER_PYTHON, + model_name="my_model", + model_version=1, + model_framework=MODEL.FRAMEWORK_SKLEARN, + ) + + # Act + serialized = p.to_dict() + + # Assert + assert "vllmVariant" not in serialized + assert "vllmImageTag" not in serialized + + def test_for_model_propagates_variant_and_image_tag(self, mocker): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + + captured_kwargs = {} + + def fake_get_predictor_for_model(model, **kwargs): + captured_kwargs.update(kwargs) + from hsml.llm.predictor import Predictor as LLMPredictor + + return LLMPredictor( + name=kwargs["model_name"], + vllm_variant=kwargs.get("vllm_variant", PREDICTOR.VLLM_VARIANT_VLLM), + vllm_image_tag=kwargs.get("vllm_image_tag"), + ) + + mocker.patch( + "hopsworks_common.util.get_predictor_for_model", + side_effect=fake_get_predictor_for_model, + ) + + class MockModel: + name = "my_llm" + version = 1 + model_path = "llm_model_path" + + mock_model = MockModel() + + # Act + p = predictor.Predictor.for_model( + mock_model, + vllm_variant=PREDICTOR.VLLM_VARIANT_OMNI, + vllm_image_tag="v0.14.0", 
+ ) + + # Assert + assert captured_kwargs["vllm_variant"] == PREDICTOR.VLLM_VARIANT_OMNI + assert captured_kwargs["vllm_image_tag"] == "v0.14.0" + assert p.vllm_variant == PREDICTOR.VLLM_VARIANT_OMNI + assert p.vllm_image_tag == "v0.14.0" + + def test_update_from_response_json_preserves_vllm_fields( + self, mocker, backend_fixtures + ): + # Arrange + self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT) + p_json = backend_fixtures["predictor"][ + "get_deployment_vllm_kserve_omni_variant" + ]["response"] + + p = predictor.Predictor.from_response_json(p_json) + + # Refresh from the same JSON (no field is modified beforehand) + p.update_from_response_json(p_json) + + # Assert both fields survive the refresh + assert p.vllm_variant == "VLLM_OMNI" + assert p.vllm_image_tag == "v0.14.0" + # auxiliary methods def _mock_serving_variables(