diff --git a/docs/book/component-guide/model-deployers/vertex.md b/docs/book/component-guide/model-deployers/vertex.md new file mode 100644 index 00000000000..0bd4ab175b5 --- /dev/null +++ b/docs/book/component-guide/model-deployers/vertex.md @@ -0,0 +1,182 @@ +# Vertex AI Model Deployer + +[Vertex AI](https://cloud.google.com/vertex-ai) provides managed infrastructure for deploying machine learning models at scale. The Vertex AI Model Deployer in ZenML allows you to deploy models to Vertex AI endpoints, providing a scalable and fully managed solution for model serving. + +## When to use it? + +Use the Vertex AI Model Deployer when: + +- You are leveraging Google Cloud Platform (GCP) and wish to integrate with its native ML serving infrastructure. +- You need enterprise-grade model serving capabilities complete with autoscaling and GPU acceleration. +- You require a fully managed solution that abstracts away the operational overhead of serving models. +- You need to deploy models directly from your Vertex AI Model Registry—or even from other registries or artifacts. +- You want seamless integration with GCP services like Cloud Logging, IAM, and VPC. + +This deployer is especially useful for production deployments, high-availability serving, and dynamic scaling based on workloads. + +{% hint style="info" %} +For best results, the Vertex AI Model Deployer works with a Vertex AI Model Registry in your ZenML stack. This allows you to register models with detailed metadata and configuration and then deploy a specific version seamlessly. +{% endhint %} + +## How to deploy it? + +The Vertex AI Model Deployer is enabled via the ZenML GCP integration. First, install the integration: + +```shell +zenml integration install gcp -y +``` + +### Authentication and Service Connector Configuration + +The deployer requires proper GCP authentication. The recommended approach is to use the ZenML Service Connector: + +```shell +# Register the service connector with a service account key +zenml service-connector register vertex_deployer_connector \ + --type gcp \ + --auth-method=service-account \ + --project_id= \ + --service_account_json=@vertex-deployer-sa.json \ + --resource-type gcp-generic + +# Register the model deployer and connect it to the service connector +zenml model-deployer register vertex_deployer \ + --flavor=vertex \ + --location=us-central1 \ + --connector vertex_deployer_connector +``` + +{% hint style="info" %} +The service account used for deployment must have the following permissions: +- `Vertex AI User` to enable model deployments +- `Vertex AI Service Agent` for model endpoint management +- `Storage Object Viewer` if the model artifacts reside in Google Cloud Storage +{% endhint %} + +## How to use it + +A complete usage example is available in the [ZenML Examples repository](https://github.com/zenml-io/zenml-projects/tree/main/vertex-registry-and-deployer). + +### Deploying a Model in a Pipeline + +Below is an example of a deployment step that uses the updated configuration options. In this example, the deployment configuration supports: + +- **Model versioning**: Explicitly provide the model version (using the full resource name from the model registry). +- **Display name and Sync mode**: Fields such as `display_name` (for a friendly endpoint name) and `sync` (to wait for deployment completion) are now available. +- **Traffic configuration**: Route a certain percentage (e.g., 100%) of traffic to this deployment. 
+- **Advanced options**: You can still specify custom container settings, resource specifications (including GPU options), and explanation configuration via shared classes from `vertex_base_config.py`. + +```python +from typing_extensions import Annotated +from zenml import ArtifactConfig, get_step_context, step +from zenml.client import Client +from zenml.integrations.gcp.services.vertex_deployment import ( + VertexDeploymentConfig, + VertexDeploymentService, +) + +@step(enable_cache=False) +def model_deployer( + model_registry_uri: str, + is_promoted: bool = False, +) -> Annotated[ + VertexDeploymentService, + ArtifactConfig(name="vertex_deployment", is_deployment_artifact=True), +]: + """Model deployer step. + + Args: + model_registry_uri: The full resource name of the model in the registry. + is_promoted: Flag indicating if the model is promoted to production. + + Returns: + The deployed model service. + """ + if not is_promoted: + # Skip deployment if the model is not promoted. + return None + else: + zenml_client = Client() + current_model = get_step_context().model + model_deployer = zenml_client.active_stack.model_deployer + + # Create deployment configuration with advanced options. + vertex_deployment_config = VertexDeploymentConfig( + location="europe-west1", + name=current_model.name, # Unique endpoint name in Vertex AI. + display_name="zenml-vertex-quickstart", + model_name=model_registry_uri, # Fully qualified model name (from model registry). + model_version=current_model.version, # Specify the model version explicitly. + description="An example of deploying a model using the Vertex AI Model Deployer", + sync=True, # Wait for deployment to complete before proceeding. + traffic_percentage=100, # Route 100% of traffic to this model version. + # (Optional) Advanced configurations: + # container=VertexAIContainerSpec( + # image_uri="your-custom-image:latest", + # ports=[8080], + # env={"ENV_VAR": "value"} + # ), + # resources=VertexAIResourceSpec( + # accelerator_type="NVIDIA_TESLA_T4", + # accelerator_count=1, + # machine_type="n1-standard-4", + # min_replica_count=1, + # max_replica_count=3, + # ), + # explanation=VertexAIExplanationSpec( + # metadata={"method": "integrated-gradients"}, + # parameters={"num_integral_steps": 50} + # ) + ) + + service = model_deployer.deploy_model( + config=vertex_deployment_config, + service_type=VertexDeploymentService.SERVICE_TYPE, + ) + + return service +``` + +### Configuration Options + +The Vertex AI Model Deployer leverages a comprehensive configuration system defined in the shared base configuration and deployer-specific settings: + +- **Basic Settings:** + - `location`: The GCP region for deployment (e.g., "us-central1" or "europe-west1"). + - `name`: Unique identifier for the deployed endpoint. + - `display_name`: A human-friendly name for the endpoint. + - `model_name`: The fully qualified model name from the model registry. + - `model_version`: The version of the model to deploy. + - `description`: A textual description of the deployment. + - `sync`: A flag to indicate whether the deployment should wait until completion. + - `traffic_percentage`: The percentage of incoming traffic to route to this deployment. + +- **Container and Resource Configuration:** + - Configurations provided via VertexAIContainerSpec allow you to specify a custom serving container image, HTTP routes (`predict_route`, `health_route`), environment variables, and port exposure. 
+ - VertexAIResourceSpec lets you override the default machine type, number of replicas, and even GPU options. + +- **Advanced Settings:** + - Service account, network configuration, and customer-managed encryption keys. + - Model explanation settings via `VertexAIExplanationSpec` if you need integrated model interpretability. + +These options are defined across the Vertex AI Base Config and the deployer–specific configuration in VertexModelDeployerFlavor. + +### Limitations and Considerations + +1. **Stack Requirements:** + - It is recommended to pair the deployer with a Vertex AI Model Registry in your stack. + - Compatible with both local and remote orchestrators. + - Requires valid GCP credentials and permissions. + +2. **Authentication:** + - Best practice is to use service connectors for secure and managed authentication. + - Supports multiple authentication methods (service accounts, local credentials). + +3. **Costs:** + - Vertex AI endpoints will incur costs based on machine type and uptime. + - Utilize autoscaling (via configured `min_replica_count` and `max_replica_count`) to manage cost. + +4. **Region Consistency:** + - Ensure that the model and deployment are created in the same GCP region. + +For more details, please refer to the [SDK docs](https://sdkdocs.zenml.io). \ No newline at end of file diff --git a/docs/book/component-guide/model-registries/vertex.md b/docs/book/component-guide/model-registries/vertex.md new file mode 100644 index 00000000000..4a401a93595 --- /dev/null +++ b/docs/book/component-guide/model-registries/vertex.md @@ -0,0 +1,203 @@ +# Vertex AI Model Registry + +[Vertex AI](https://cloud.google.com/vertex-ai) is Google Cloud's unified ML platform that helps you build, deploy, and scale ML models. The Vertex AI Model Registry is a centralized repository for managing your ML models throughout their lifecycle. With ZenML's Vertex AI Model Registry integration, you can register model versions—with extended configuration options—track metadata, and seamlessly deploy your models using Vertex AI's managed infrastructure. + +## When would you want to use it? + +You should consider using the Vertex AI Model Registry when: + +- You're already using Google Cloud Platform (GCP) and want to leverage its native ML infrastructure. +- You need enterprise-grade model management with fine-grained access control. +- You want to track model lineage and metadata in a centralized location. +- You're building ML pipelines that integrate with other Vertex AI services. +- You need to deploy models with custom configurations such as defined container images, resource specifications, and additional metadata. + +This registry is particularly useful in scenarios where you: +- Build production ML pipelines that require deployment to Vertex AI endpoints. +- Manage multiple versions of models across development, staging, and production. +- Need to register model versions with detailed configuration for robust deployment. + +{% hint style="warning" %} +**Important:** The Vertex AI Model Registry implementation only supports the model **version** interface—not the model interface. This means that you cannot directly register, update, or delete models; you only have operations for model versions. A model container is automatically created with the first version, and subsequent uploads with the same display name create new versions. +{% endhint %} + +## How do you deploy it? + +The Vertex AI Model Registry flavor is enabled through the ZenML GCP integration. 
First, install the integration: + +```shell +zenml integration install gcp -y +``` + +### Authentication and Service Connector Configuration + +Vertex AI requires proper GCP authentication. The recommended configuration is via the ZenML Service Connector, which supports both service-account-based authentication and local gcloud credentials. + +1. **Using a GCP Service Connector with a service account (Recommended):** + ```shell + # Register the service connector with a service account key + zenml service-connector register vertex_registry_connector \ + --type gcp \ + --auth-method=service-account \ + --project_id= \ + --service_account_json=@vertex-registry-sa.json \ + --resource-type gcp-generic + + # Register the model registry + zenml model-registry register vertex_registry \ + --flavor=vertex \ + --location=us-central1 + + # Connect the model registry to the service connector + zenml model-registry connect vertex_registry --connector vertex_registry_connector + ``` +2. **Using local gcloud credentials:** + ```shell + # Register the model registry using local gcloud auth + zenml model-registry register vertex_registry \ + --flavor=vertex \ + --location=us-central1 + ``` + +{% hint style="info" %} +The service account needs the following permissions: +- `Vertex AI User` role for creating and managing model versions. +- `Storage Object Viewer` role if accessing models stored in Google Cloud Storage. +{% endhint %} + +## How do you use it? + +### Registering Models inside a Pipeline with Extended Configuration + +The Vertex AI Model Registry supports extended configuration options via the `VertexAIModelConfig` class. This means you can specify additional details for your deployments such as: + +- **Container configuration**: Use the `VertexAIContainerSpec` to define a custom serving container (e.g., specifying the `image_uri`, `predict_route`, `health_route`, and exposed ports). +- **Resource configuration**: Use the `VertexAIResourceSpec` to specify compute resources like `machine_type`, `min_replica_count`, and `max_replica_count`. +- **Additional metadata and labels**: Annotate your model registrations with pipeline details, stage information, and custom labels. + +Below is an example of how you might register a model version in your ZenML pipeline: + +```python +from typing_extensions import Annotated + +from zenml import ArtifactConfig, get_step_context, step +from zenml.client import Client +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIContainerSpec, + VertexAIModelConfig, + VertexAIResourceSpec, +) +from zenml.logger import get_logger +from zenml.model_registries.base_model_registry import ( + ModelRegistryModelMetadata, +) + +logger = get_logger(__name__) + + +@step(enable_cache=False) +def model_register( + is_promoted: bool = False, +) -> Annotated[str, ArtifactConfig(name="model_registry_uri")]: + """Model registration step. + + Registers a model version in the Vertex AI Model Registry with extended configuration + and returns the full resource name of the registered model. + + Extended configuration includes settings for container, resources, and metadata which can then be reused in + subsequent model deployments. 
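+
+    Args:
+        is_promoted: Whether the model version should be registered; when
+            False, registration is skipped.
+
+    Returns:
+        The full resource name of the registered model version, or an empty
+        string if registration was skipped.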
+ """ + if is_promoted: + # Get the current model from the step context + current_model = get_step_context().model + + client = Client() + model_registry = client.active_stack.model_registry + # Create an extended model configuration using Vertex AI base settings + model_config = VertexAIModelConfig( + location="europe-west1", + container=VertexAIContainerSpec( + image_uri="europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest", + predict_route="predict", + health_route="health", + ports=[8080], + ), + resources=VertexAIResourceSpec( + machine_type="n1-standard-4", + min_replica_count=1, + max_replica_count=1, + ), + labels={"env": "production"}, + description="Extended model configuration for Vertex AI", + ) + + # Register the model version with the extended configuration as metadata + model_version = model_registry.register_model_version( + name=current_model.name, + version=str(current_model.version), + model_source_uri=current_model.get_model_artifact("sklearn_classifier").uri, + description="ZenML model version registered with extended configuration", + metadata=ModelRegistryModelMetadata( + zenml_pipeline_name=get_step_context().pipeline.name, + zenml_pipeline_run_uuid=str(get_step_context().pipeline_run.id), + zenml_step_name=get_step_context().step_run.name, + ), + config=model_config, + ) + logger.info(f"Model version {model_version.version} registered in Model Registry") + + # Return the full resource name of the registered model + return model_version.registered_model.name + else: + return "" +``` + +### Working with Model Versions + +Since the Vertex AI Model Registry supports only version-level operations, here are some commands to manage model versions: + +```shell +# List all model versions +zenml model-registry models list-versions + +# Get details of a specific model version +zenml model-registry models get-version -v + +# Delete a model version +zenml model-registry models delete-version -v +``` + +### Configuration Options + +The Vertex AI Model Registry accepts several configuration options, now enriched with extended settings: + +- **location**: The GCP region where your resources will be created (e.g., "us-central1" or "europe-west1"). +- **project_id**: (Optional) A GCP project ID override. +- **credentials**: (Optional) GCP credentials configuration. +- **container**: (Optional) Detailed container settings (defined via `VertexAIContainerSpec`) for the model's serving container such as: + - `image_uri` + - `predict_route` + - `health_route` + - `ports` +- **resources**: (Optional) Compute resource settings (using `VertexAIResourceSpec`) like `machine_type`, `min_replica_count`, and `max_replica_count`. +- **labels** and **metadata**: Additional annotation data for organizing and tracking your model versions. + +### Key Differences from Other Model Registries + +1. **Version-Only Interface**: Vertex AI only supports version-level operations for model registration. +2. **Authentication**: Uses GCP service connectors and local credentials integrated via ZenML. +3. **Extended Configuration**: Register model versions with detailed settings for container, resources, and metadata through `VertexAIModelConfig`. +4. **Managed Service**: As a fully managed service, Vertex AI handles infrastructure management while you focus on your ML models. + +## Limitations + +- The methods `register_model()`, `update_model()`, and `delete_model()` are not implemented; you can only work with model versions. 
+- It is recommended to specify a serving container image URI rather than rely on the default scikit-learn container to ensure compatibility with Vertex AI endpoints. +- All models registered through this integration are automatically labeled with `managed_by="zenml"` for consistent tracking. + +For more detailed information, check out the [SDK docs](https://sdkdocs.zenml.io/0.80.1/integration_code_docs/integrations-gcp.html#zenml.integrations.gcp). + +
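+Outside of a pipeline, you can also look up a registered version directly
+through the active stack's registry. A minimal sketch (the model name and
+version below are placeholders):
+
+```python
+from zenml.client import Client
+
+# Query the Vertex AI Model Registry from the active ZenML stack.
+model_registry = Client().active_stack.model_registry
+
+# "my-model" and "1" are placeholder values for a registered display name
+# and Vertex AI version ID.
+model_version = model_registry.get_model_version(name="my-model", version="1")
+print(model_version.model_source_uri)
+```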
\ No newline at end of file diff --git a/src/zenml/cli/model_registry.py b/src/zenml/cli/model_registry.py index c326dfcd9a3..fe99601a01c 100644 --- a/src/zenml/cli/model_registry.py +++ b/src/zenml/cli/model_registry.py @@ -18,6 +18,7 @@ import click +from zenml import __version__ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.enums import StackComponentType @@ -642,7 +643,7 @@ def register_model_version( # Parse metadata metadata = dict(metadata) if metadata else {} registered_metadata = ModelRegistryModelMetadata(**dict(metadata)) - registered_metadata.zenml_version = zenml_version + registered_metadata.zenml_version = zenml_version or __version__ registered_metadata.zenml_run_name = zenml_run_name registered_metadata.zenml_pipeline_name = zenml_pipeline_name registered_metadata.zenml_step_name = zenml_step_name diff --git a/src/zenml/integrations/gcp/__init__.py b/src/zenml/integrations/gcp/__init__.py index a14427ca421..8a3788c699a 100644 --- a/src/zenml/integrations/gcp/__init__.py +++ b/src/zenml/integrations/gcp/__init__.py @@ -34,6 +34,11 @@ GCP_VERTEX_ORCHESTRATOR_FLAVOR = "vertex" GCP_VERTEX_STEP_OPERATOR_FLAVOR = "vertex" +# Model deployer constants +VERTEX_MODEL_REGISTRY_FLAVOR = "vertex" +VERTEX_MODEL_DEPLOYER_FLAVOR = "vertex" +VERTEX_SERVICE_ARTIFACT = "vertex_deployment_service" + # Service connector constants GCP_CONNECTOR_TYPE = "gcp" GCP_RESOURCE_TYPE = "gcp-generic" @@ -75,6 +80,8 @@ def flavors(cls) -> List[Type[Flavor]]: VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, VertexStepOperatorFlavor, + VertexModelDeployerFlavor, + VertexModelRegistryFlavor, ) return [ @@ -83,4 +90,6 @@ def flavors(cls) -> List[Type[Flavor]]: VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, VertexStepOperatorFlavor, + VertexModelRegistryFlavor, + VertexModelDeployerFlavor, ] diff --git a/src/zenml/integrations/gcp/flavors/__init__.py b/src/zenml/integrations/gcp/flavors/__init__.py index e70f4937594..b78b574a80e 100644 --- a/src/zenml/integrations/gcp/flavors/__init__.py +++ b/src/zenml/integrations/gcp/flavors/__init__.py @@ -33,6 +33,14 @@ VertexStepOperatorConfig, VertexStepOperatorFlavor, ) +from zenml.integrations.gcp.flavors.vertex_model_deployer_flavor import ( + VertexModelDeployerConfig, + VertexModelDeployerFlavor, +) +from zenml.integrations.gcp.flavors.vertex_model_registry_flavor import ( + VertexAIModelRegistryConfig, + VertexModelRegistryFlavor, +) __all__ = [ "GCPArtifactStoreFlavor", @@ -45,4 +53,8 @@ "VertexOrchestratorConfig", "VertexStepOperatorFlavor", "VertexStepOperatorConfig", + "VertexModelDeployerFlavor", + "VertexModelDeployerConfig", + "VertexModelRegistryFlavor", + "VertexAIModelRegistryConfig", ] diff --git a/src/zenml/integrations/gcp/flavors/vertex_base_config.py b/src/zenml/integrations/gcp/flavors/vertex_base_config.py new file mode 100644 index 00000000000..e2872411ba6 --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_base_config.py @@ -0,0 +1,199 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""Shared configuration classes for Vertex AI components.""" + +from typing import Any, Dict, Optional, Sequence + +from pydantic import BaseModel, Field + +from zenml.config.base_settings import BaseSettings + + +class VertexAIContainerSpec(BaseModel): + """Container specification for Vertex AI models and endpoints.""" + + image_uri: Optional[str] = Field( + None, description="Docker image URI for model serving" + ) + command: Optional[Sequence[str]] = Field( + None, description="Container command to run" + ) + args: Optional[Sequence[str]] = Field( + None, description="Container command arguments" + ) + env: Optional[Dict[str, str]] = Field( + None, description="Environment variables" + ) + ports: Optional[Sequence[int]] = Field( + None, description="Container ports to expose" + ) + predict_route: Optional[str] = Field( + None, description="HTTP path for prediction requests" + ) + health_route: Optional[str] = Field( + None, description="HTTP path for health check requests" + ) + + +class VertexAIResourceSpec(BaseModel): + """Resource specification for Vertex AI deployments.""" + + machine_type: Optional[str] = Field( + None, description="Compute instance machine type" + ) + accelerator_type: Optional[str] = Field( + None, description="Hardware accelerator type" + ) + accelerator_count: Optional[int] = Field( + None, description="Number of accelerators" + ) + min_replica_count: Optional[int] = Field( + 1, description="Minimum number of replicas" + ) + max_replica_count: Optional[int] = Field( + 1, description="Maximum number of replicas" + ) + + +class VertexAIExplanationSpec(BaseModel): + """Explanation configuration for Vertex AI models.""" + + metadata: Optional[Dict[str, Any]] = Field( + None, description="Explanation metadata" + ) + parameters: Optional[Dict[str, Any]] = Field( + None, description="Explanation parameters" + ) + + +class VertexAIBaseConfig(BaseModel): + """Base configuration shared by Vertex AI components. 
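+
+    Fields map onto the corresponding fields of the Vertex AI Model and
+    Endpoint REST resources referenced below.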
+ + Reference: + - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models + - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints + """ + + # Basic settings + location: str = Field( + "us-central1", description="GCP region for Vertex AI resources" + ) + project_id: Optional[str] = Field( + None, description="Optional project ID override" + ) + + # Container configuration + container: Optional[VertexAIContainerSpec] = Field( + None, description="Container configuration" + ) + + # Resource configuration + resources: Optional[VertexAIResourceSpec] = Field( + None, description="Resource configuration" + ) + + # Service configuration + service_account: Optional[str] = Field( + None, description="Service account email" + ) + network: Optional[str] = Field(None, description="VPC network") + + # Security + encryption_spec_key_name: Optional[str] = Field( + None, description="Customer-managed encryption key" + ) + + # Monitoring and logging + enable_access_logging: Optional[bool] = Field( + None, description="Enable access logging" + ) + disable_container_logging: Optional[bool] = Field( + None, description="Disable container logging" + ) + + # Model explanation + explanation: Optional[VertexAIExplanationSpec] = Field( + None, description="Model explanation configuration" + ) + + # Labels and metadata + labels: Optional[Dict[str, str]] = Field( + None, description="Resource labels" + ) + metadata: Optional[Dict[str, str]] = Field( + None, description="Custom metadata" + ) + + +class VertexAIModelConfig(VertexAIBaseConfig): + """Configuration specific to Vertex AI Models.""" + + # Model metadata + display_name: Optional[str] = None + description: Optional[str] = None + version_description: Optional[str] = None + version_aliases: Optional[Sequence[str]] = None + + # Model artifacts + artifact_uri: Optional[str] = None + model_source_spec: Optional[Dict[str, Any]] = None + + # Model versioning + is_default_version: Optional[bool] = None + + # Model formats + supported_deployment_resources_types: Optional[Sequence[str]] = None + supported_input_storage_formats: Optional[Sequence[str]] = None + supported_output_storage_formats: Optional[Sequence[str]] = None + + # Training metadata + training_pipeline_display_name: Optional[str] = None + training_pipeline_id: Optional[str] = None + + # Model optimization + model_source_info: Optional[Dict[str, str]] = None + original_model_info: Optional[Dict[str, str]] = None + containerized_model_optimization: Optional[Dict[str, Any]] = None + + +class VertexAIEndpointConfig(VertexAIBaseConfig): + """Configuration specific to Vertex AI Endpoints.""" + + # Endpoint metadata + display_name: Optional[str] = None + description: Optional[str] = None + + # Traffic configuration + traffic_split: Optional[Dict[str, int]] = None + traffic_percentage: Optional[int] = 0 + + # Autoscaling + autoscaling_target_cpu_utilization: Optional[float] = None + autoscaling_target_accelerator_duty_cycle: Optional[float] = None + + # Deployment + sync: Optional[bool] = False + deploy_request_timeout: Optional[int] = None + existing_endpoint: Optional[str] = None + + +class VertexAIBaseSettings(BaseSettings): + """Base settings for Vertex AI components.""" + + location: str = Field( + "us-central1", description="Default GCP region for Vertex AI resources" + ) + project_id: Optional[str] = Field( + None, description="Optional project ID override" + ) diff --git a/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py 
b/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py new file mode 100644 index 00000000000..afd249385f3 --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py @@ -0,0 +1,132 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Vertex AI model deployer flavor.""" + +from typing import TYPE_CHECKING, Optional, Type + +from zenml.integrations.gcp import ( + GCP_RESOURCE_TYPE, + VERTEX_MODEL_DEPLOYER_FLAVOR, +) +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIBaseSettings, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsConfigMixin, +) +from zenml.model_deployers.base_model_deployer import ( + BaseModelDeployerConfig, + BaseModelDeployerFlavor, +) +from zenml.models.v2.misc.service_connector_type import ( + ServiceConnectorRequirements, +) + +if TYPE_CHECKING: + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + +class VertexModelDeployerConfig( + BaseModelDeployerConfig, + GoogleCredentialsConfigMixin, + VertexAIBaseSettings, +): + """Configuration for the Vertex AI model deployer. + + This configuration combines: + - Base model deployer configuration + - Google Cloud authentication + - Vertex AI Base configuration + """ + + +class VertexModelDeployerFlavor(BaseModelDeployerFlavor): + """Vertex AI model deployer flavor.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return VERTEX_MODEL_DEPLOYER_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=GCP_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png" + + @property + def config_class(self) -> Type[VertexModelDeployerConfig]: + """Returns `VertexModelDeployerConfig` config class. + + Returns: + The config class. 
+ """ + return VertexModelDeployerConfig + + @property + def implementation_class(self) -> Type["VertexModelDeployer"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + return VertexModelDeployer diff --git a/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py new file mode 100644 index 00000000000..5055d3ae5dc --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py @@ -0,0 +1,130 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""VertexAI model registry flavor.""" + +from typing import TYPE_CHECKING, Optional, Type + +from zenml.integrations.gcp import ( + GCP_RESOURCE_TYPE, + VERTEX_MODEL_REGISTRY_FLAVOR, +) +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIBaseSettings, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsConfigMixin, +) +from zenml.model_registries.base_model_registry import ( + BaseModelRegistryConfig, + BaseModelRegistryFlavor, +) +from zenml.models import ServiceConnectorRequirements + +if TYPE_CHECKING: + from zenml.integrations.gcp.model_registries import ( + VertexAIModelRegistry, + ) + + +class VertexAIModelRegistryConfig( + BaseModelRegistryConfig, + GoogleCredentialsConfigMixin, + VertexAIBaseSettings, +): + """Configuration for the VertexAI model registry. + + This configuration combines: + - Base model registry configuration + - Google Cloud authentication + - Vertex AI Base configuration + """ + + +class VertexModelRegistryFlavor(BaseModelRegistryFlavor): + """Model registry flavor for VertexAI models.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return VERTEX_MODEL_REGISTRY_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=GCP_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. 
+ """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png" + + @property + def config_class(self) -> Type[VertexAIModelRegistryConfig]: + """Returns `VertexAIModelRegistryConfig` config class. + + Returns: + The config class. + """ + return VertexAIModelRegistryConfig + + @property + def implementation_class(self) -> Type["VertexAIModelRegistry"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from zenml.integrations.gcp.model_registries import ( + VertexAIModelRegistry, + ) + + return VertexAIModelRegistry diff --git a/src/zenml/integrations/gcp/model_deployers/__init__.py b/src/zenml/integrations/gcp/model_deployers/__init__.py new file mode 100644 index 00000000000..203f57c096f --- /dev/null +++ b/src/zenml/integrations/gcp/model_deployers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the Vertex AI model deployers.""" + +from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( # noqa + VertexModelDeployer, +) + +__all__ = ["VertexModelDeployer"] diff --git a/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py b/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py new file mode 100644 index 00000000000..d31e5bc6ee7 --- /dev/null +++ b/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py @@ -0,0 +1,248 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Vertex AI Model Deployer.""" + +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, cast +from uuid import UUID + +from google.cloud import aiplatform + +from zenml.analytics.enums import AnalyticsEvent +from zenml.analytics.utils import track_handler +from zenml.client import Client +from zenml.enums import StackComponentType +from zenml.integrations.gcp.flavors.vertex_model_deployer_flavor import ( + VertexModelDeployerConfig, + VertexModelDeployerFlavor, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.integrations.gcp.model_registries.vertex_model_registry import ( + VertexAIModelRegistry, +) +from zenml.integrations.gcp.services.vertex_deployment import ( + VertexDeploymentConfig, + VertexDeploymentService, +) +from zenml.logger import get_logger +from zenml.model_deployers import BaseModelDeployer +from zenml.model_deployers.base_model_deployer import ( + DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + BaseModelDeployerFlavor, +) +from zenml.services import BaseService, ServiceConfig +from zenml.stack.stack import Stack +from zenml.stack.stack_validator import StackValidator + +logger = get_logger(__name__) + + +class VertexModelDeployer(BaseModelDeployer, GoogleCredentialsMixin): + """Vertex AI endpoint model deployer.""" + + NAME: ClassVar[str] = "Vertex AI" + FLAVOR: ClassVar[Type["BaseModelDeployerFlavor"]] = ( + VertexModelDeployerFlavor + ) + + @property + def config(self) -> VertexModelDeployerConfig: + """Returns the `VertexModelDeployerConfig` config. + + Returns: + The configuration. + """ + return cast(VertexModelDeployerConfig, self._config) + + def _init_vertex_client( + self, + credentials: Optional[Any] = None, + ) -> None: + """Initialize Vertex AI client with proper credentials. + + Args: + credentials: Optional credentials to use + """ + if not credentials: + credentials, project_id = self._get_authentication() + + # Initialize with per-instance credentials + aiplatform.init( + project=project_id, + location=self.config.location, + credentials=credentials, + ) + + @property + def validator(self) -> Optional[StackValidator]: + """Validates that the stack contains a Vertex AI model registry. + + Returns: + A StackValidator instance. + """ + + def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]: + """Validates stack requirements. + + Args: + stack: The stack to validate. + + Returns: + A tuple of (is_valid, error_message). + """ + model_registry = stack.model_registry + if not isinstance(model_registry, VertexAIModelRegistry): + return False, ( + "The Vertex AI model deployer requires a Vertex AI model " + "registry to be present in the stack. Please add a Vertex AI " + "model registry to the stack." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.MODEL_REGISTRY, + }, + custom_validation_function=_validate_stack_requirements, + ) + + def _create_deployment_service( + self, id: UUID, timeout: int, config: VertexDeploymentConfig + ) -> VertexDeploymentService: + """Creates a new VertexAIDeploymentService. 
+ + Args: + id: the UUID of the model to be deployed + timeout: timeout in seconds for deployment operations + config: deployment configuration + + Returns: + The VertexDeploymentService instance + """ + # Create service instance + service = VertexDeploymentService(uuid=id, config=config) + logger.info("Creating Vertex AI deployment service with ID %s", id) + + # Start the service + service.start(timeout=timeout) + return service + + def perform_deploy_model( + self, + id: UUID, + config: ServiceConfig, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Deploy a model to Vertex AI. + + Args: + id: the UUID of the service to be created + config: deployment configuration + timeout: timeout for deployment operations + + Returns: + The deployment service instance + """ + with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: + config = cast(VertexDeploymentConfig, config) + + # Create and start deployment service + service = self._create_deployment_service( + id=id, config=config, timeout=timeout + ) + + # Track analytics + client = Client() + stack = client.active_stack + stack_metadata = { + component_type.value: component.flavor + for component_type, component in stack.components.items() + } + analytics_handler.metadata = { + "store_type": client.zen_store.type.value, + **stack_metadata, + } + + return service + + def perform_stop_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> BaseService: + """Stop a Vertex AI deployment service. + + Args: + service: The service to stop + timeout: Timeout for stop operation + force: Whether to force stop + + Returns: + The stopped service + """ + service.stop(timeout=timeout, force=force) + return service + + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Start a Vertex AI deployment service. + + Args: + service: The service to start + timeout: Timeout for start operation + + Returns: + The started service + """ + service.start(timeout=timeout) + return service + + def perform_delete_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Delete a Vertex AI deployment service. + + Args: + service: The service to delete + timeout: Timeout for delete operation + force: Whether to force delete + """ + service = cast(VertexDeploymentService, service) + service.stop(timeout=timeout, force=force) + + @staticmethod + def get_model_server_info( # type: ignore[override] + service_instance: "VertexDeploymentService", + ) -> Dict[str, Optional[str]]: + """Get information about the deployed model server. + + Args: + service_instance: The deployment service instance + + Returns: + Dict containing server information + """ + return { + "prediction_url": service_instance.get_prediction_url(), + "status": service_instance.status.state.value, + } diff --git a/src/zenml/integrations/gcp/model_registries/__init__.py b/src/zenml/integrations/gcp/model_registries/__init__.py new file mode 100644 index 00000000000..672c7c19619 --- /dev/null +++ b/src/zenml/integrations/gcp/model_registries/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the Vertex AI model registry.""" + +from zenml.integrations.gcp.model_registries.vertex_model_registry import ( + VertexAIModelRegistry +) + +__all__ = ["VertexAIModelRegistry"] diff --git a/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py b/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py new file mode 100644 index 00000000000..07503775148 --- /dev/null +++ b/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py @@ -0,0 +1,837 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Vertex AI model registry integration for ZenML.""" + +import base64 +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, cast + +from google.cloud import aiplatform + +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIModelConfig, +) +from zenml.integrations.gcp.flavors.vertex_model_registry_flavor import ( + VertexAIModelRegistryConfig, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.integrations.gcp.utils import sanitize_vertex_label +from zenml.logger import get_logger +from zenml.model_registries.base_model_registry import ( + BaseModelRegistry, + ModelRegistryModelMetadata, + ModelVersionStage, + RegisteredModel, + RegistryModelVersion, +) + +logger = get_logger(__name__) + +# Constants for Vertex AI limitations +MAX_LABEL_COUNT = 64 +MAX_LABEL_KEY_LENGTH = 63 +MAX_LABEL_VALUE_LENGTH = 63 +MAX_DISPLAY_NAME_LENGTH = 128 + + +# Helper function to safely get values from metadata dict +def _get_metadata_value( + metadata: Dict[str, Any], key: str, default: Any = None +) -> Any: + """Safely retrieves a value from a dictionary.""" + return metadata.get(key, default) + + +class VertexAIModelRegistry(BaseModelRegistry, GoogleCredentialsMixin): + """Register models using Vertex AI.""" + + @property + def config(self) -> VertexAIModelRegistryConfig: + """Returns the config of the model registry. + + Returns: + The configuration. + """ + return cast(VertexAIModelRegistryConfig, self._config) + + + def _prepare_labels( + self, + metadata: Optional[Dict[str, str]] = None, + stage: Optional[ModelVersionStage] = None, + ) -> Dict[str, str]: + """Prepare labels for Vertex AI model. 
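+
+        The returned labels always include `managed_by="zenml"`; keys and
+        values are sanitized to Vertex AI's 63-character limit, and the
+        total label count is capped at 64.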
+ + Args: + metadata: Optional metadata to include as labels + stage: Optional model version stage + + Returns: + Dictionary of sanitized labels + """ + labels = {} + + # Add base labels + labels["managed_by"] = "zenml" + # Add stage if provided + if stage: + labels["stage"] = sanitize_vertex_label(stage.value) + + # Process metadata if provided + if metadata: + for key, value in metadata.items(): + # Sanitize both key and value + sanitized_key = sanitize_vertex_label(str(key)) + sanitized_value = sanitize_vertex_label(str(value)) + # Only add if both key and value are valid + if sanitized_key and sanitized_value: + labels[sanitized_key] = sanitized_value + + # Ensure we don't exceed 64 labels + if len(labels) > 64: + # Keep essential labels and truncate the rest + essential_labels = { + k: labels[k] for k in ["managed_by", "stage"] if k in labels + } + # Add remaining labels up to limit + remaining_slots = 64 - len(essential_labels) + other_labels = { + k: v + for i, (k, v) in enumerate(labels.items()) + if k not in essential_labels and i < remaining_slots + } + labels = {**essential_labels, **other_labels} + + return labels + + def _get_model_id(self, name: str) -> str: + """Get the full Vertex AI model ID. + + Args: + name: Model name + + Returns: + str: Full model ID in format: projects/{project}/locations/{location}/models/{model} + """ + _, project_id = self._get_authentication() + model_id = f"projects/{project_id}/locations/{self.config.location}/models/{name}" + return model_id + + def _get_model_version_id(self, model_id: str, version: str) -> str: + """Get the full Vertex AI model version ID. + + Args: + model_id: Full model ID + version: Version string + + Returns: + str: Full model version ID in format: {model_id}/versions/{version} + """ + model_version_id = f"{model_id}/versions/{version}" + return model_version_id + + def _init_vertex_model( + self, name: str, version: Optional[str] = None + ) -> Optional[aiplatform.Model]: + """Initialize a single Vertex AI model with proper credentials. + + This method returns one Vertex AI model based on the given name (and optional version). + + Args: + name: The model name. + version: The model version (optional). + + Returns: + A single Vertex AI model instance or None if initialization fails. + """ + credentials, project_id = self._get_authentication() + location = self.config.location + kwargs = { + "location": location, + "project": project_id, + "credentials": credentials, + } + + if name.startswith("projects/"): + kwargs["model_name"] = name + else: + # Attempt to find an existing model by display_name + existing_models = aiplatform.Model.list( + filter=f"display_name={name}", + project=self.config.project or project_id, + location=location, + ) + if existing_models: + kwargs["model_name"] = existing_models[0].resource_name + else: + model_id = self._get_model_id(name) + if version: + model_id = self._get_model_version_id(model_id, version) + kwargs["model_name"] = model_id + try: + return aiplatform.Model(**kwargs) + except Exception as e: + logger.warning(f"Failed to initialize model: {e}") + return None + + def register_model( + self, + name: str, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> RegisteredModel: + """Register a model to the Vertex AI model registry. + + Args: + name: The name of the model. + description: The description of the model. + metadata: The metadata of the model. 
+ + Raises: + NotImplementedError: Vertex AI does not support registering models, you can only register model versions, skipping model registration... + + """ + raise NotImplementedError( + "Vertex AI does not support registering models, you can only register model versions, skipping model registration..." + ) + + def delete_model( + self, + name: str, + ) -> None: + """Delete a model and all of its versions from the Vertex AI model registry. + + Args: + name: The name of the model. + + Raises: + RuntimeError: if model deletion fails + """ + try: + model = self._init_vertex_model(name=name) + if isinstance(model, aiplatform.Model): + model.delete() + logger.info(f"Deleted model '{name}' and all its versions.") + except Exception as e: + raise RuntimeError(f"Failed to delete model: {str(e)}") + + def update_model( + self, + name: str, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + remove_metadata: Optional[List[str]] = None, + ) -> RegisteredModel: + """Update a model in the Vertex AI model registry. + + Args: + name: The name of the model. + description: The description of the model. + metadata: The metadata of the model. + remove_metadata: The metadata to remove from the model. + + Raises: + NotImplementedError: Vertex AI does not support updating models, you can only update model versions, skipping model registration... + """ + raise NotImplementedError( + "Vertex AI does not support updating models, you can only update model versions, skipping model registration..." + ) + + def get_model(self, name: str) -> RegisteredModel: + """Get a model from the Vertex AI model registry by name without needing a version. + + Args: + name: The name of the model. + + Returns: + The registered model. + + Raises: + RuntimeError: if model retrieval fails + """ + try: + # Fetch by display_name, and use unique labels to ensure multi-tenancy + model = aiplatform.Model(display_name=name) + except Exception as e: + raise RuntimeError(f"Failed to get model: {str(e)}") + return RegisteredModel( + name=model.display_name, + description=model.description, + metadata=model.labels, + ) + + def list_models( + self, + name: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> List[RegisteredModel]: + """List models in the Vertex AI model registry. + + Args: + name: The name of the model. + metadata: The metadata of the model. + + Returns: + The registered models. + + Raises: + RuntimeError: If the models are not found + """ + credentials, project_id = self._get_authentication() + location = self.config.location + # Always filter with ZenML-specific labels (including deployer id for multi-tenancy) + filter_expr = "labels.managed_by=zenml" + + if name: + filter_expr += f" AND display_name={name}" + if metadata: + for key, value in metadata.items(): + filter_expr += f" AND labels.{key}={value}" + try: + all_models = aiplatform.Model.list( + project=project_id, + location=location, + filter=filter_expr, + credentials=credentials, + ) + # Deduplicate by display_name so only one entry per "logical" model is returned. 
+ unique_models = {model.display_name: model for model in all_models} + return [ + RegisteredModel( + name=parent_model.display_name, + description=parent_model.description, + metadata=parent_model.labels, + ) + for parent_model in unique_models.values() + ] + except Exception as e: + raise RuntimeError(f"Failed to list models: {str(e)}") + + def _extract_vertex_config_from_metadata( + self, metadata: Dict[str, Any] + ) -> "VertexAIModelConfig": + """Extracts Vertex AI specific configuration from metadata dictionary. + + Args: + metadata: The metadata dictionary potentially containing config overrides. + + Returns: + A VertexAIModelConfig instance populated from metadata. + """ + # Use the module-level helper function + container_config_dict = _get_metadata_value(metadata, "container", {}) + container_config = None + if isinstance(container_config_dict, dict) and container_config_dict: + from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIContainerSpec, + ) + + container_config = VertexAIContainerSpec(**container_config_dict) + + explanation_config_dict = _get_metadata_value( + metadata, "explanation", {} + ) + explanation_config = None + if ( + isinstance(explanation_config_dict, dict) + and explanation_config_dict + ): + from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIExplanationSpec, + ) + + explanation_config = VertexAIExplanationSpec( + **explanation_config_dict + ) + + # Use the module-level helper function and correct instantiation + return VertexAIModelConfig( + # Model metadata overrides + display_name=_get_metadata_value(metadata, "display_name"), + description=_get_metadata_value(metadata, "description"), + version_description=_get_metadata_value( + metadata, "version_description" + ), + version_aliases=_get_metadata_value(metadata, "version_aliases"), + # Model artifacts overrides + artifact_uri=_get_metadata_value(metadata, "artifact_uri"), + # Model versioning overrides + is_default_version=_get_metadata_value( + metadata, "is_default_version" + ), + # Model formats overrides (less likely used here, but for completeness) + supported_deployment_resources_types=_get_metadata_value( + metadata, "supported_deployment_resources_types" + ), + supported_input_storage_formats=_get_metadata_value( + metadata, "supported_input_storage_formats" + ), + supported_output_storage_formats=_get_metadata_value( + metadata, "supported_output_storage_formats" + ), + # Container and Explanation config (parsed above) + container=container_config, + explanation=explanation_config, + # GCP Base config (from component config) + encryption_spec_key_name=_get_metadata_value( + metadata, "encryption_spec_key_name" + ), + ) + + def register_model_version( + self, + name: str, + version: Optional[str] = None, + model_source_uri: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[ModelRegistryModelMetadata] = None, + **kwargs: Any, + ) -> RegistryModelVersion: + """Register a model version to the Vertex AI model registry. + + Args: + name: Model name + version: Model version + model_source_uri: URI to model artifacts (overrides metadata if provided) + description: Model description (overrides metadata if provided) + metadata: Model metadata (expected to be a ModelRegistryModelMetadata or + equivalent serializable dict). Can contain overrides for + Vertex AI model parameters like 'display_name', 'artifact_uri', + 'version_description', 'container', 'explanation', etc. + config: Vertex AI model configuration overrides. 
+ **kwargs: Additional arguments + + Returns: + RegistryModelVersion instance + """ + # Prepare labels with internal ZenML metadata, ensuring they are sanitized + metadata_dict = metadata.model_dump() if metadata else {} + labels = self._prepare_labels(metadata_dict) + if version: + labels["user_version"] = sanitize_vertex_label(version) + + # Extract Vertex AI specific config overrides from metadata + vertex_config = self._extract_vertex_config_from_metadata( + metadata_dict + ) + + # Use a consistently sanitized display name. Prioritize metadata, then name arg. + model_display_name_override = vertex_config.display_name + model_display_name = ( + model_display_name_override + or self._sanitize_model_display_name(name) + ) + + # Determine serving container image URI: prioritize metadata container config, + # then metadata direct key, then default. + serving_container_image_uri = "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest" # Default + if "serving_container_image_uri" in metadata_dict: + serving_container_image_uri = metadata_dict[ + "serving_container_image_uri" + ] + if vertex_config.container and vertex_config.container.image_uri: + serving_container_image_uri = vertex_config.container.image_uri + + # Determine artifact URI: prioritize direct argument, then metadata, then log warning. + final_artifact_uri = model_source_uri or vertex_config.artifact_uri + if not final_artifact_uri: + logger.warning( + "No 'artifact_uri' provided in function arguments or metadata. " + "Model registration might fail or use an unexpected artifact source." + ) + + # Determine description: prioritize direct argument, then metadata. + final_description = description or vertex_config.description + + # Build extended upload arguments for vertex.Model.upload, + # leveraging extracted config from metadata and component config for core details. + upload_arguments = { + # Core GCP config from component + "project": self.config.project_id or self.config.project, + "location": self.config.location or vertex_config.location, + # Model identification and artifacts + "display_name": model_display_name, + "artifact_uri": final_artifact_uri, + # Description and Versioning - prioritize metadata + "description": final_description, + "version_description": vertex_config.version_description, + "version_aliases": vertex_config.version_aliases, + "is_default_version": vertex_config.is_default_version + if vertex_config.is_default_version is not None + else True, + # Container configuration from metadata + "serving_container_image_uri": serving_container_image_uri, + "serving_container_predict_route": vertex_config.container.predict_route + if vertex_config.container + else None, + "serving_container_health_route": vertex_config.container.health_route + if vertex_config.container + else None, + "serving_container_command": vertex_config.container.command + if vertex_config.container + else None, + "serving_container_args": vertex_config.container.args + if vertex_config.container + else None, + "serving_container_environment_variables": vertex_config.container.env + if vertex_config.container + else None, + "serving_container_ports": vertex_config.container.ports + if vertex_config.container + else None, + # Labels and Encryption + "labels": labels, + "encryption_spec_key_name": vertex_config.encryption_spec_key_name, + } + + # Include explanation settings if provided in metadata config. 
+ if vertex_config.explanation: + upload_arguments["explanation_metadata"] = ( + vertex_config.explanation.metadata + ) + upload_arguments["explanation_parameters"] = ( + vertex_config.explanation.parameters + ) + + # Remove any parameters that are None to avoid passing them to upload. + upload_arguments = { + k: v for k, v in upload_arguments.items() if v is not None + } + + # Try to get existing parent model, but don't fail if it doesn't exist + # Use the actual model name `name` for lookup, not the potentially overridden display name + parent_model = self._init_vertex_model(name=name, version=version) + + # If parent model exists and has same URI, return existing version + # Check against final_artifact_uri used for upload + if parent_model and parent_model.uri == final_artifact_uri: + logger.info( + f"Model version {version} targeting artifact URI " + f"'{final_artifact_uri}' already exists, skipping upload..." + ) + return self._vertex_model_to_registry_version(parent_model) + + # Set parent model resource name if it exists + if parent_model: + # Ensure the display_name matches the parent model if it exists, + # otherwise upload might create a *new* model instead of a version. + # Use the parent model's display name for the upload. + upload_arguments["display_name"] = parent_model.display_name + upload_arguments["parent_model"] = parent_model.resource_name + logger.info( + f"Found existing parent model '{parent_model.display_name}' " + f"({parent_model.resource_name}). Uploading as a new version." + ) + else: + logger.info( + f"No existing parent model found for name '{name}'. " + f"A new model named '{upload_arguments['display_name']}' will be created." + ) + + # Upload the model + try: + logger.info( + f"Uploading model to Vertex AI with arguments: { {k: v for k, v in upload_arguments.items() if k != 'labels'} }" + ) # Don't log potentially large labels dict + model = aiplatform.Model.upload(**upload_arguments) + logger.info( + f"Uploaded new model version with labels: {model.labels}" + ) + except Exception as e: + logger.error(f"Failed to upload model to Vertex AI: {e}") + # Log the arguments again on failure for easier debugging + logger.error(f"Failed upload arguments: {upload_arguments}") + raise + + return self._vertex_model_to_registry_version(model) + + def delete_model_version( + self, + name: str, + version: str, + ) -> None: + """Delete a model version from the Vertex AI model registry. + + Args: + name: Model name + version: Version string + + Raises: + RuntimeError: If the model version is not found + """ + try: + model = self._init_vertex_model(name=name, version=version) + if model is None: + raise RuntimeError( + f"Model version '{version}' for '{name}' not found." + ) + model.versioning_registry.delete_version(version) + logger.info(f"Deleted model version: {name} version {version}") + except Exception as e: + raise RuntimeError(f"Failed to delete model version: {str(e)}") + + def update_model_version( + self, + name: str, + version: str, + description: Optional[str] = None, + metadata: Optional[ModelRegistryModelMetadata] = None, + remove_metadata: Optional[List[str]] = None, + stage: Optional[ModelVersionStage] = None, + ) -> RegistryModelVersion: + """Update a model version in the Vertex AI model registry. + + Args: + name: The name of the model. + version: The version of the model. + description: The description of the model. + metadata: The metadata of the model. + remove_metadata: The metadata to remove from the model. + stage: The stage of the model. 
+
+        Returns:
+            The updated model version.
+
+        Raises:
+            RuntimeError: If the model version is not found
+        """
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(f"Model '{name}' not found.")
+            sanitized_version = sanitize_vertex_label(version)
+            target_version = None
+            for v in parent_model.list():
+                if v.labels.get("user_version") == sanitized_version:
+                    target_version = v
+                    break
+            if target_version is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            labels = target_version.labels or {}
+            if metadata:
+                metadata_dict = metadata.model_dump()
+                for key, value in metadata_dict.items():
+                    labels[sanitize_vertex_label(key)] = sanitize_vertex_label(
+                        str(value)
+                    )
+            if remove_metadata:
+                for key in remove_metadata:
+                    labels.pop(sanitize_vertex_label(key), None)
+            if stage:
+                labels["stage"] = stage.value.lower()
+            target_version.update(description=description, labels=labels)
+        except Exception as e:
+            raise RuntimeError(f"Failed to update model version: {str(e)}")
+        return self.get_model_version(name, version)
+
+    def get_model_version(
+        self, name: str, version: str
+    ) -> RegistryModelVersion:
+        """Get a model version from the Vertex AI model registry using the version label.
+
+        Args:
+            name: The name of the model.
+            version: The version of the model.
+
+        Returns:
+            The registered model version.
+
+        Raises:
+            RuntimeError: If the model version is not found
+        """
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            return self._vertex_model_to_registry_version(parent_model)
+        except Exception as e:
+            raise RuntimeError(f"Failed to get model version: {str(e)}")
+
+    def list_model_versions(
+        self,
+        name: Optional[str] = None,
+        model_source_uri: Optional[str] = None,
+        metadata: Optional[ModelRegistryModelMetadata] = None,
+        stage: Optional[ModelVersionStage] = None,
+        count: Optional[int] = None,
+        created_after: Optional[datetime] = None,
+        created_before: Optional[datetime] = None,
+        order_by_date: Optional[str] = None,
+        **kwargs: Any,
+    ) -> List[RegistryModelVersion]:
+        """List model versions from the Vertex AI model registry.
+
+        Args:
+            name: The name of the model.
+            model_source_uri: The URI of the model source.
+            metadata: The metadata of the model.
+            stage: The stage of the model.
+            count: The maximum number of model versions to return.
+            created_after: The date after which the model versions were created.
+            created_before: The date before which the model versions were created.
+            order_by_date: The order in which to sort results by creation
+                date ("asc" or "desc").
+            **kwargs: Additional arguments
+
+        Returns:
+            The registered model versions.
+
+        Raises:
+            RuntimeError: If the model versions cannot be listed
+        """
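+        # Illustration (hypothetical values): a call such as
+        #   list_model_versions(name="fraud-classifier",
+        #                       created_after=datetime(2024, 1, 1))
+        # builds a Vertex AI list filter along the lines of:
+        #   display_name=fraud-classifier AND create_time>2024-01-01T00:00:00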
+        credentials, project_id = self._get_authentication()
+        location = self.config.location
+        filter_expr = []
+        if name:
+            filter_expr.append(
+                f"display_name={self._sanitize_model_display_name(name)}"
+            )
+        if metadata:
+            for key, value in metadata.model_dump().items():
+                filter_expr.append(
+                    f"labels.{sanitize_vertex_label(key)}={sanitize_vertex_label(str(value))}"
+                )
+        if stage:
+            # Stages are stored as lowercased label values (see update_model_version).
+            filter_expr.append(f"labels.stage={stage.value.lower()}")
+        if created_after:
+            filter_expr.append(f"create_time>{created_after.isoformat()}")
+        if created_before:
+            filter_expr.append(f"create_time<{created_before.isoformat()}")
+
+        filter_str = " AND ".join(filter_expr) if filter_expr else None
+
+        # Sort by creation time if requested ("desc" for newest first).
+        order_by = None
+        if order_by_date:
+            order_by = (
+                "create_time desc"
+                if order_by_date.lower() == "desc"
+                else "create_time"
+            )
+
+        try:
+            # `aiplatform.Model` cannot be instantiated from a filter; use the
+            # list API to find matching models, then collect their versions.
+            matching_models = aiplatform.Model.list(
+                filter=filter_str,
+                order_by=order_by,
+                project=project_id,
+                location=location,
+                credentials=credentials,
+            )
+            results = []
+            for model in matching_models:
+                for version_info in model.versioning_registry.list_versions():
+                    # Materialize each version as a full Model object so the
+                    # converter has access to labels, URI, and timestamps.
+                    versioned_model = aiplatform.Model(
+                        model_name=model.resource_name,
+                        version=version_info.version_id,
+                        credentials=credentials,
+                    )
+                    results.append(
+                        self._vertex_model_to_registry_version(versioned_model)
+                    )
+            if model_source_uri:
+                # Vertex AI cannot filter by artifact URI server-side, so
+                # filter the results client-side.
+                results = [
+                    result
+                    for result in results
+                    if result.model_source_uri == model_source_uri
+                ]
+            if count:
+                results = results[:count]
+            return results
+        except Exception as e:
+            raise RuntimeError(f"Failed to list model versions: {str(e)}")
+
+    def load_model_version(
+        self,
+        name: str,
+        version: str,
+        **kwargs: Any,
+    ) -> Any:
+        """Load a model version from the Vertex AI model registry using label-based lookup.
+
+        Args:
+            name: The name of the model.
+            version: The version of the model.
+            **kwargs: Additional arguments
+
+        Returns:
+            The loaded model version.
+
+        Raises:
+            RuntimeError: If the model version is not found
+        """
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            return parent_model
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model version: {str(e)}")
+
+    def get_model_uri_artifact_store(
+        self,
+        model_version: RegistryModelVersion,
+    ) -> str:
+        """Get the model URI artifact store.
+
+        Args:
+            model_version: The model version.
+
+        Returns:
+            The model URI artifact store.
+        """
+        return model_version.model_source_uri
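+
+    # Illustrative usage (hypothetical names): with this registry in the
+    # active stack, a version can be fetched and its artifact location
+    # resolved via:
+    #   registry = Client().active_stack.model_registry
+    #   model_version = registry.get_model_version("fraud-classifier", "3")
+    #   artifact_uri = registry.get_model_uri_artifact_store(model_version)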
+
+    def _vertex_model_to_registry_version(
+        self, model: aiplatform.Model
+    ) -> RegistryModelVersion:
+        """Convert Vertex AI model to ZenML RegistryModelVersion.
+
+        Args:
+            model: Vertex AI Model instance
+
+        Returns:
+            RegistryModelVersion instance
+        """
+        # Extract stage from labels if present. The label value is stored
+        # lowercased, so match the enum values case-insensitively.
+        stage = ModelVersionStage.NONE
+        if model.labels and "stage" in model.labels:
+            stage_label = model.labels["stage"]
+            for candidate in ModelVersionStage:
+                if candidate.value.lower() == stage_label.lower():
+                    stage = candidate
+                    break
+
+        # Get parent model for registered_model field
+        try:
+            registered_model = RegisteredModel(
+                name=model.display_name,
+                description=model.description,
+                metadata=model.labels,
+            )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get parent model for version {model.resource_name}: {e}"
+            )
+            registered_model = RegisteredModel(
+                name=model.display_name if model.display_name else "unknown",
+                description=model.description if model.description else "",
+                metadata=model.labels if model.labels else {},
+            )
+
+        # Copy the labels so the Vertex AI model object is not mutated in place.
+        model_version_metadata = dict(model.labels or {})
+        model_version_metadata["resource_name"] = model.resource_name
+        return RegistryModelVersion(
+            registered_model=registered_model,
+            version=model.version_id,
+            model_source_uri=model.uri,
+            model_format="Custom",  # Vertex AI doesn't provide format info
+            description=model.description,
+            metadata=model_version_metadata,
+            created_at=model.create_time,
+            last_updated_at=model.update_time,
+            stage=stage,
+        )
+
+    def _sanitize_model_display_name(self, name: str) -> str:
+        """Sanitize the model display name to conform to Vertex AI limits.
+
+        Args:
+            name: The name of the model.
+
+        Returns:
+            The sanitized model name.
+        """
+        name = sanitize_vertex_label(name)
+        if len(name) > MAX_DISPLAY_NAME_LENGTH:
+            logger.warning(
+                f"Model name '{name}' exceeds {MAX_DISPLAY_NAME_LENGTH} characters; truncating."
+            )
+            name = name[:MAX_DISPLAY_NAME_LENGTH]
+        return name
diff --git a/src/zenml/integrations/gcp/services/__init__.py b/src/zenml/integrations/gcp/services/__init__.py
new file mode 100644
index 00000000000..392a48e9694
--- /dev/null
+++ b/src/zenml/integrations/gcp/services/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""Initialization of the Vertex Service."""
+
+from zenml.integrations.gcp.services.vertex_deployment import (  # noqa
+    VertexDeploymentConfig,
+    VertexDeploymentService,
+)
+
+__all__ = ["VertexDeploymentConfig", "VertexDeploymentService"]
\ No newline at end of file
diff --git a/src/zenml/integrations/gcp/services/vertex_deployment.py b/src/zenml/integrations/gcp/services/vertex_deployment.py
new file mode 100644
index 00000000000..a1c7d8b8610
--- /dev/null
+++ b/src/zenml/integrations/gcp/services/vertex_deployment.py
@@ -0,0 +1,465 @@
+# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the Vertex AI Deployment service.""" + +from datetime import datetime +from typing import Any, Dict, Generator, List, Optional, Tuple, cast + +from google.api_core import retry +from google.cloud import aiplatform +from pydantic import Field, PrivateAttr + +from zenml.client import Client +from zenml.enums import ServiceState +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIEndpointConfig, +) +from zenml.integrations.gcp.utils import sanitize_vertex_label +from zenml.logger import get_logger +from zenml.models.v2.misc.service import ServiceType +from zenml.services import ServiceStatus +from zenml.services.service import BaseDeploymentService, ServiceConfig +from zenml.services.service_endpoint import ( + BaseServiceEndpoint, + ServiceEndpointConfig, +) + +logger = get_logger(__name__) + +# Constants +POLLING_TIMEOUT = 1800 # 30 minutes +RETRY_DEADLINE = 600 # 10 minutes +UUID_SLICE_LENGTH: int = 8 + +# Retry configuration for transient errors +retry_config = retry.Retry( + initial=1.0, # Initial delay in seconds + maximum=60.0, # Maximum delay + multiplier=2.0, # Delay multiplier + deadline=RETRY_DEADLINE, + predicate=retry.if_transient_error, +) + + +class VertexDeploymentConfig(VertexAIEndpointConfig, ServiceConfig): + """Vertex AI service configurations.""" + + def get_vertex_deployment_labels(self) -> Dict[str, str]: + """Generate labels for the VertexAI deployment from the service configuration. + + Returns: + A dictionary of labels for the VertexAI deployment. 
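+
+        Example (illustrative output):
+            {"managed_by": "zenml", "pipeline-name": "training-pipeline",
+            "model-name": "fraud-classifier"}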
+ """ + labels = self.labels or {} + labels["managed_by"] = "zenml" + if self.pipeline_name: + labels["pipeline-name"] = sanitize_vertex_label(self.pipeline_name) + if self.pipeline_step_name: + labels["step-name"] = sanitize_vertex_label( + self.pipeline_step_name + ) + if self.model_name: + labels["model-name"] = sanitize_vertex_label(self.model_name) + if self.service_name: + labels["service-name"] = sanitize_vertex_label(self.service_name) + if self.display_name: + labels["display-name"] = sanitize_vertex_label( + self.display_name + ) or sanitize_vertex_label(self.name) + return labels + + +class VertexPredictionServiceEndpointConfig(ServiceEndpointConfig): + """Vertex AI Prediction Service Endpoint.""" + + endpoint_name: Optional[str] = None + deployed_model_id: Optional[str] = None + endpoint_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + state: Optional[str] = None + + +class VertexServiceStatus(ServiceStatus): + """Vertex AI service status.""" + + +class VertexPredictionServiceEndpoint(BaseServiceEndpoint): + """Vertex AI Prediction Service Endpoint.""" + + config: VertexPredictionServiceEndpointConfig + + +class VertexDeploymentService(BaseDeploymentService): + """Vertex AI model deployment service.""" + + SERVICE_TYPE = ServiceType( + name="vertex-deployment", + type="model-serving", + flavor="vertex", + description="Vertex AI inference endpoint prediction service", + ) + config: VertexDeploymentConfig + status: VertexServiceStatus = Field( + default_factory=lambda: VertexServiceStatus() + ) + _project_id: Optional[str] = PrivateAttr(default=None) + _credentials: Optional[Any] = PrivateAttr(default=None) + + def _initialize_gcp_clients(self) -> None: + """Initialize GCP clients with consistent credentials.""" + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + model_deployer = cast( + VertexModelDeployer, Client().active_stack.model_deployer + ) + + # Get credentials from model deployer + ( + self._credentials, + self._project_id, + ) = model_deployer._get_authentication() + + def __init__(self, config: VertexDeploymentConfig, **attrs: Any): + """Initialize the Vertex AI deployment service. + + Args: + config: The configuration for the Vertex AI deployment service. + **attrs: Additional attributes for the service. + """ + super().__init__(config=config, **attrs) + self._initialize_gcp_clients() + + @property + def prediction_url(self) -> Optional[str]: + """The prediction URI exposed by the prediction service. + + Returns: + The prediction URI exposed by the prediction service. + """ + endpoints = self.get_endpoints() + if not endpoints: + return None + endpoint = endpoints[0] + return f"https://{self.config.location}-aiplatform.googleapis.com/v1/{endpoint.resource_name}" + + def get_endpoints(self) -> List[aiplatform.Endpoint]: + """Get all endpoints for the current project and location. 
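+
+        Only endpoints labeled ``managed_by=zenml`` whose ``display-name``
+        label matches this service's sanitized name are returned.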
+ + Returns: + List of Vertex AI endpoints + """ + try: + # Use proper filtering and pagination + display_name = self.config.name or self.config.display_name + assert display_name is not None + display_name = sanitize_vertex_label(display_name) + return list( + aiplatform.Endpoint.list( + filter=f"labels.managed_by=zenml AND labels.display-name={display_name}", + project=self._project_id, + location=self.config.location, + credentials=self._credentials, + ) + ) + except Exception as e: + logger.error(f"Failed to list endpoints: {e}") + return [] + + def _generate_endpoint_name(self) -> str: + """Generate a unique name for the Vertex AI Inference Endpoint. + + Returns: + Generated endpoint name + """ + # Make name more descriptive and conformant + sanitized_model_name = sanitize_vertex_label( + self.config.display_name or self.config.name + ) + return f"{sanitized_model_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}" + + def _get_model_id(self, name: str) -> str: + """Helper to construct a full model ID from a given model name. + + Args: + name: The name of the model. + + Returns: + The full model ID. + """ + model_id = f"projects/{self._project_id}/locations/{self.config.location}/models/{name}" + return model_id + + def _verify_model_exists(self) -> aiplatform.Model: + """Verify the model exists and return it. + + Returns: + Vertex AI Model instance + """ + if self.config.model_name.startswith("projects/"): + model_name = self.config.model_name + else: + model_name = self._get_model_id(self.config.model_name) + # Remove version suffix if present + if "@" in model_name: + model_name = model_name.split("@")[0] + logger.info(f"Model name: {model_name}") + logger.info(f"Project ID: {self._project_id}") + logger.info(f"Location: {self.config.location}") + model = aiplatform.Model( + model_name=model_name, + project=self._project_id, + location=self.config.location, + credentials=self._credentials, + ) + logger.info(f"Found model to deploy: {model.resource_name}") + return model + + def _deploy_model( + self, model: aiplatform.Model, endpoint: aiplatform.Endpoint + ) -> None: + """Deploy model to Vertex AI endpoint. + + Args: + model: The model to deploy. + endpoint: The endpoint to deploy the model to. 
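+
+        Note:
+            Container, resource, explanation, service account, network and
+            encryption settings from the service configuration are translated
+            below into keyword arguments for ``Endpoint.deploy``.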
+        """
+        # Prepare deployment configuration
+        deploy_kwargs = {
+            "model": model,
+            "deployed_model_display_name": self.config.display_name
+            or self.config.name,
+            # Honor the configured traffic split, defaulting to all traffic.
+            "traffic_percentage": self.config.traffic_percentage or 100,
+            # Honor the configured sync flag, defaulting to async deployment.
+            "sync": bool(self.config.sync),
+        }
+        # Add container configuration if specified
+        if self.config.container:
+            deploy_kwargs.update(
+                {
+                    "container_image_uri": self.config.container.image_uri,
+                    "container_ports": self.config.container.ports,
+                    "container_predict_route": self.config.container.predict_route,
+                    "container_health_route": self.config.container.health_route,
+                    "container_env": self.config.container.env,
+                }
+            )
+
+        # Add resource configuration if specified
+        if self.config.resources:
+            deploy_kwargs.update(
+                {
+                    "machine_type": self.config.resources.machine_type,
+                    "min_replica_count": self.config.resources.min_replica_count,
+                    "max_replica_count": self.config.resources.max_replica_count,
+                    "accelerator_type": self.config.resources.accelerator_type,
+                    "accelerator_count": self.config.resources.accelerator_count,
+                }
+            )
+
+        # Add explanation configuration if specified
+        if self.config.explanation:
+            deploy_kwargs.update(
+                {
+                    "explanation_metadata": self.config.explanation.metadata,
+                    "explanation_parameters": self.config.explanation.parameters,
+                }
+            )
+
+        # Add service account if specified
+        if self.config.service_account:
+            deploy_kwargs["service_account"] = self.config.service_account
+
+        # Add network configuration if specified
+        if self.config.network:
+            deploy_kwargs["network"] = self.config.network
+
+        # Add encryption key if specified
+        if self.config.encryption_spec_key_name:
+            deploy_kwargs["encryption_spec_key_name"] = (
+                self.config.encryption_spec_key_name
+            )
+
+        # Deploy model
+        logger.info(
+            f"Deploying model to endpoint with kwargs: {deploy_kwargs}"
+        )
+        endpoint.deploy(**deploy_kwargs)
+
+    def provision(self) -> None:
+        """Provision or update remote Vertex AI deployment instance.
+
+        Raises:
+            Exception: if model deployment fails
+        """
+        # First verify the model exists (this also logs its resource name).
+        model = self._verify_model_exists()
+
+        # Get or create endpoint
+        if self.config.existing_endpoint:
+            endpoint = aiplatform.Endpoint(
+                endpoint_name=self.config.existing_endpoint,
+                location=self.config.location,
+                credentials=self._credentials,
+            )
+            logger.info(f"Using existing endpoint: {endpoint.resource_name}")
+        else:
+            endpoint_name = self._generate_endpoint_name()
+            endpoint = aiplatform.Endpoint.create(
+                display_name=endpoint_name,
+                location=self.config.location,
+                encryption_spec_key_name=self.config.encryption_spec_key_name,
+                labels=self.config.get_vertex_deployment_labels(),
+                credentials=self._credentials,
+            )
+            logger.info(f"Created new endpoint: {endpoint.resource_name}")
+
+        # Deploy the model, surfacing any failure in the service status.
+        try:
+            self._deploy_model(model, endpoint)
+            logger.info(
+                f"Model {model.resource_name} deployed to endpoint {endpoint.resource_name}"
+            )
+        except Exception as e:
+            self.status.update_state(
+                ServiceState.ERROR, f"Deployment failed: {str(e)}"
+            )
+            raise
+
+        self.status.update_state(ServiceState.ACTIVE)
+        logger.info(
+            f"Deployment completed successfully. "
+            f"Endpoint: {endpoint.resource_name}"
+        )
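+
+    # Illustrative lifecycle (hypothetical flow): once `provision()` succeeds,
+    # the service can be inspected and torn down via the base service API:
+    #   state, message = service.check_status()  # e.g. (ServiceState.ACTIVE, "")
+    #   service.deprovision()  # undeploys models and deletes the endpoint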
+
+    def deprovision(self, force: bool = False) -> None:
+        """Deprovision the Vertex AI deployment.
+
+        Args:
+            force: Whether to ignore errors during deprovisioning
+
+        Raises:
+            RuntimeError: if deprovisioning fails and `force` is not set
+        """
+        endpoints = self.get_endpoints()
+        if not endpoints:
+            # Nothing to tear down: the endpoint was never created or has
+            # already been deleted.
+            logger.warning("No endpoint found to deprovision.")
+            self.status.update_state(ServiceState.INACTIVE)
+            return
+
+        try:
+            endpoint = endpoints[0]
+            # Undeploy all models from the endpoint first.
+            endpoint.undeploy_all()
+            # Only delete the endpoint itself if this service created it.
+            if not self.config.existing_endpoint:
+                endpoint.delete()
+            logger.info(f"Deprovisioned endpoint: {endpoint.resource_name}")
+            self.status.update_state(ServiceState.INACTIVE)
+        except Exception as e:
+            error_msg = f"Failed to deprovision deployment: {str(e)}"
+            if not force:
+                logger.error(error_msg)
+                self.status.update_state(ServiceState.ERROR, error_msg)
+                raise RuntimeError(error_msg)
+            logger.warning(
+                f"Error during forced deprovision (ignoring): {error_msg}"
+            )
+            self.status.update_state(ServiceState.INACTIVE)
+
+    def get_logs(
+        self,
+        follow: bool = False,
+        tail: Optional[int] = None,
+    ) -> Generator[str, bool, None]:
+        """Retrieve logs for the Vertex AI deployment (not supported).
+
+        Args:
+            follow: Whether to follow the logs.
+            tail: The number of lines to tail.
+
+        Yields:
+            Nothing; log retrieval is not supported for Vertex AI.
+        """
+        logger.warning("Logs are not supported for Vertex AI")
+        yield from ()
+
+    def check_status(self) -> Tuple[ServiceState, str]:
+        """Check the status of the deployment by validating if an endpoint exists and if it has deployed models.
+
+        Returns:
+            A tuple containing the deployment's state and a status message.
+        """
+        try:
+            endpoints = self.get_endpoints()
+            if not endpoints:
+                return ServiceState.INACTIVE, "No endpoint found."
+
+            endpoint = endpoints[0]
+            deployed_models = []
+            if hasattr(endpoint, "list_models"):
+                try:
+                    deployed_models = endpoint.list_models()
+                except Exception as e:
+                    logger.warning(f"Failed to list models for endpoint: {e}")
+            elif hasattr(endpoint, "deployed_models"):
+                deployed_models = endpoint.deployed_models or []
+
+            if deployed_models:
+                return ServiceState.ACTIVE, ""
+            return (
+                ServiceState.PENDING_STARTUP,
+                "Endpoint deployment is in progress.",
+            )
+        except Exception as e:
+            return ServiceState.ERROR, f"Deployment check failed: {e}"
+
+    @property
+    def is_running(self) -> bool:
+        """Check if the service is running.
+
+        Returns:
+            True if the service is running, False otherwise.
+        """
+        self.update_status()
+        return self.status.state == ServiceState.ACTIVE
diff --git a/src/zenml/integrations/gcp/utils.py b/src/zenml/integrations/gcp/utils.py
new file mode 100644
index 00000000000..cbdae568b36
--- /dev/null
+++ b/src/zenml/integrations/gcp/utils.py
@@ -0,0 +1,42 @@
+# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""GCP utils.""" + +import re + + +def sanitize_vertex_label(value: str) -> str: + """Sanitize a label value to comply with Vertex AI requirements. + + Args: + value: The label value to sanitize + + Returns: + Sanitized label value + """ + if not value: + return "" + + # Convert to lowercase + value = value.lower() + + # Replace any character that's not lowercase letter, number, dash or underscore + value = re.sub(r"[^a-z0-9\-_]", "-", value) + + # Ensure it starts with a letter/number by prepending 'x' if needed + if not value[0].isalnum(): + value = f"x{value}" + + # Truncate to 63 chars to stay under limit + return value[:63] diff --git a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py index b11f7fe7080..a796deb4863 100644 --- a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py +++ b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py @@ -13,8 +13,10 @@ # permissions and limitations under the License. """Implementation of the sklearn materializer.""" +import os from typing import Any, ClassVar, Tuple, Type +import cloudpickle from sklearn.base import ( BaseEstimator, BiclusterMixin, @@ -29,13 +31,20 @@ ) from zenml.enums import ArtifactType +from zenml.environment import Environment +from zenml.logger import get_logger from zenml.materializers.cloudpickle_materializer import ( + DEFAULT_FILENAME, CloudpickleMaterializer, ) +logger = get_logger(__name__) + +SKLEARN_MODEL_FILENAME = "model.pkl" + class SklearnMaterializer(CloudpickleMaterializer): - """Materializer to read data to and from sklearn.""" + """Materializer to read data to and from sklearn with backward compatibility.""" ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( BaseEstimator, @@ -50,3 +59,66 @@ class SklearnMaterializer(CloudpickleMaterializer): TransformerMixin, ) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + + def load(self, data_type: Type[Any]) -> Any: + """Reads a sklearn model from pickle file with backward compatibility. + + Args: + data_type: The data type of the artifact. + + Returns: + The loaded sklearn model. + + Raises: + FileNotFoundError: if model file not found + """ + # First try to load from model.pkl + model_filepath = os.path.join(self.uri, SKLEARN_MODEL_FILENAME) + artifact_filepath = os.path.join(self.uri, DEFAULT_FILENAME) + + # Check which file exists and load accordingly + if self.artifact_store.exists(model_filepath): + filepath = model_filepath + elif self.artifact_store.exists(artifact_filepath): + logger.info( + f"Loading from legacy filepath {artifact_filepath}. 
Future saves " + f"will use {model_filepath}" + ) + filepath = artifact_filepath + else: + raise FileNotFoundError( + f"Neither {model_filepath} nor {artifact_filepath} found in artifact store" + ) + + # validate python version before loading + source_python_version = self._load_python_version() + current_python_version = Environment().python_version() + if ( + source_python_version != "unknown" + and source_python_version != current_python_version + ): + logger.warning( + f"Your artifact was materialized under Python version " + f"'{source_python_version}' but you are currently using " + f"'{current_python_version}'. This might cause unexpected " + "behavior since pickle is not reproducible across Python " + "versions. Attempting to load anyway..." + ) + + # Load the model + with self.artifact_store.open(filepath, "rb") as fid: + return cloudpickle.load(fid) + + def save(self, data: Any) -> None: + """Saves a sklearn model to pickle file using the new filename. + + Args: + data: The sklearn model to save. + """ + # Save python version for validation on loading + self._save_python_version() + + # Save using the new filename + filepath = os.path.join(self.uri, SKLEARN_MODEL_FILENAME) + with self.artifact_store.open(filepath, "wb") as fid: + cloudpickle.dump(data, fid) diff --git a/src/zenml/model_deployers/base_model_deployer.py b/src/zenml/model_deployers/base_model_deployer.py index c881e8b3d12..70a1bc7caad 100644 --- a/src/zenml/model_deployers/base_model_deployer.py +++ b/src/zenml/model_deployers/base_model_deployer.py @@ -28,7 +28,7 @@ from uuid import UUID from zenml.client import Client -from zenml.enums import StackComponentType +from zenml.enums import ServiceState, StackComponentType from zenml.logger import get_logger from zenml.models.v2.misc.service import ServiceType from zenml.services import BaseService, ServiceConfig @@ -180,6 +180,12 @@ def deploy_model( logger.info( f"Existing model server found for {config.name or config.model_name} with the exact same configuration. Returning the existing service named {services[0].config.service_name}." ) + status, _ = services[0].check_status() + if status != ServiceState.ACTIVE: + logger.info( + f"Service found for {config.name or config.model_name} is not active. Starting the service." + ) + services[0].start(timeout=timeout) return services[0] else: # Find existing model server diff --git a/src/zenml/model_registries/base_model_registry.py b/src/zenml/model_registries/base_model_registry.py index 0cc8c0bbd1a..d9fe07baee8 100644 --- a/src/zenml/model_registries/base_model_registry.py +++ b/src/zenml/model_registries/base_model_registry.py @@ -70,6 +70,15 @@ class ModelRegistryModelMetadata(BaseModel): zenml_step_name: Optional[str] = None zenml_project: Optional[str] = None + @property + def managed_by(self) -> str: + """Returns the managed by attribute. + + Returns: + The managed by attribute. + """ + return "zenml" + @property def custom_attributes(self) -> Dict[str, str]: """Returns a dictionary of custom attributes. 
diff --git a/src/zenml/services/service.py b/src/zenml/services/service.py index 24f85d07f3a..e32cf32d231 100644 --- a/src/zenml/services/service.py +++ b/src/zenml/services/service.py @@ -36,6 +36,7 @@ from zenml.console import console from zenml.enums import ServiceState from zenml.logger import get_logger +from zenml.model.model import Model from zenml.models.v2.misc.service import ServiceType from zenml.services.service_endpoint import BaseServiceEndpoint from zenml.services.service_monitor import HTTPEndpointHealthMonitor @@ -110,6 +111,7 @@ class ServiceConfig(BaseTypedModel): pipeline_name: name of the pipeline that spun up the service pipeline_step_name: name of the pipeline step that spun up the service run_name: name of the pipeline run that spun up the service. + zenml_model: the ZenML model object to be deployed. """ name: str = "" diff --git a/tests/unit/services/test_service.py b/tests/unit/services/test_service.py index 080d53c081c..4c6f47e3894 100644 --- a/tests/unit/services/test_service.py +++ b/tests/unit/services/test_service.py @@ -62,7 +62,9 @@ def base_service(): return TestService( uuid=UUID("12345678-1234-5678-1234-567812345678"), admin_state=ServiceState.ACTIVE, - config=ServiceConfig(name="test_service", param1="value1", param2=2), + config=ServiceConfig( + name="test_service", param1="value1", param2=2 + ), status=ServiceStatus( state=ServiceState.ACTIVE, last_error="", @@ -78,13 +80,15 @@ def test_from_model(service_response): assert isinstance(service, TestService) assert service.uuid == service_response.id assert service.admin_state == service_response.admin_state - assert dict(service.config) == service_response.config + assert ( + service.config.model_dump(exclude_unset=True) + == service_response.config + ) assert dict(service.status) == service_response.status assert service.SERVICE_TYPE["type"] == service_response.service_type.type assert ( service.SERVICE_TYPE["flavor"] == service_response.service_type.flavor ) - assert service.endpoint == service_response.endpoint def test_update_status(base_service, monkeypatch):