Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@
StackResponseMetadata,
StackResponseResources
)
from zenml.models.v2.misc.param_groups import (
PipelineRunIdentifier,
StepRunIdentifier,
VersionedIdentifier,
)
from zenml.models.v2.misc.statistics import (
ProjectStatistics,
ServerStatistics,
Expand Down Expand Up @@ -874,4 +879,7 @@
"ProjectStatistics",
"PipelineRunDAG",
"ExceptionInfo",
"VersionedIdentifier",
"PipelineRunIdentifier",
"StepRunIdentifier",
]
113 changes: 113 additions & 0 deletions src/zenml/models/v2/misc/param_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Parameter group classes."""

from uuid import UUID

from pydantic import BaseModel, model_validator


class VersionedIdentifier(BaseModel):
    """Class grouping identifiers for entities resolved by UUID or name&version.

    Exactly one identification option must be used: either `id` on its own,
    or `name` together with `version`.

    Attributes:
        id: The UUID of the entity.
        name: The name of the entity. Must be combined with `version`.
        version: The version of the entity. Must be combined with `name`.
    """

    # All fields default to None so callers only pass the option they use.
    id: UUID | None = None
    name: str | None = None
    version: str | None = None

    @model_validator(mode="after")
    def _validate_options(self) -> "VersionedIdentifier":
        """Validates that exactly one identification option is used.

        Returns:
            The validated instance.

        Raises:
            ValueError: If both or none of id and name are set, or if only
                one of name and version is set.
        """
        if self.id and self.name:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name."
            )

        if not (self.id or self.name):
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name."
            )

        # name and version only make sense as a pair.
        if bool(self.name) ^ bool(self.version):
            raise ValueError("You need to specify both name and version.")

        return self


class PipelineRunIdentifier(BaseModel):
    """Class grouping different pipeline run identifiers.

    Exactly one of `id`, `name` or `prefix` must be set.

    Attributes:
        id: The UUID of the pipeline run.
        name: The name of the pipeline run.
        prefix: A prefix identifying the pipeline run.
    """

    # All fields default to None so callers only pass the option they use.
    id: UUID | None = None
    name: str | None = None
    prefix: str | None = None

    @property
    def value(self) -> str | UUID:
        """Resolves the set value out of id, name and prefix.

        Returns:
            The id, name or prefix (the first one set, in this exact order).
        """
        # The validator guarantees that exactly one option is set, so the
        # chain never falls through to None on a validated instance.
        return self.id or self.name or self.prefix  # type: ignore[return-value]

    @model_validator(mode="after")
    def _validate_options(self) -> "PipelineRunIdentifier":
        """Validates that exactly one identification option is used.

        Returns:
            The validated instance.

        Raises:
            ValueError: If more than one or none of id, name and prefix
                are set.
        """
        options_set = sum(
            bool(option) for option in (self.id, self.name, self.prefix)
        )

        if options_set > 1:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name or prefix."
            )

        if options_set == 0:
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name or prefix."
            )

        return self


class StepRunIdentifier(BaseModel):
    """Class grouping different step run identifiers.

    A step run is identified either by `id` on its own, or by `name`
    together with the `pipeline` run it belongs to.

    Attributes:
        id: The UUID of the step run.
        name: The name of the step. Must be combined with `pipeline`.
        pipeline: Identifier of the pipeline run the step belongs to. Must be
            combined with `name`.
    """

    # `name` has to be optional: the validator rejects id and name being set
    # together, so a required name would make identification by id unusable.
    id: UUID | None = None
    name: str | None = None
    pipeline: PipelineRunIdentifier | None = None

    @model_validator(mode="after")
    def _validate_options(self) -> "StepRunIdentifier":
        """Validates that exactly one identification option is used.

        Returns:
            The validated instance.

        Raises:
            ValueError: If both or none of id and name are set, or if only
                one of name and pipeline is set.
        """
        if self.id and self.name:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name."
            )

        if not (self.id or self.name):
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name."
            )

        # name and pipeline only make sense as a pair.
        if bool(self.name) ^ bool(self.pipeline):
            raise ValueError(
                "To identify a run by name you need to specify a pipeline run identifier."
            )

        return self
26 changes: 26 additions & 0 deletions src/zenml/models/v2/misc/run_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Utility classes for modeling run metadata."""

from datetime import datetime
from typing import Any
from uuid import UUID

from pydantic import BaseModel, Field
Expand All @@ -28,6 +29,31 @@ class RunMetadataResource(BaseModel):
id: UUID = Field(title="The ID of the resource.")
type: MetadataResourceTypes = Field(title="The type of the resource.")

def __eq__(self, other: Any) -> bool:
    """Overrides equality operator.

    Two resources are equal when both their IDs and their types match,
    which is consistent with the fields used by `__hash__`.

    Args:
        other: The object to compare.

    Returns:
        True if the object is equal to the given object. `NotImplemented`
        is returned for objects that are not a RunMetadataResource, so
        that Python falls back to its default comparison handling
        instead of failing on e.g. `resource == None`.
    """
    if not isinstance(other, RunMetadataResource):
        return NotImplemented

    # Compare the fields directly rather than the hashes: equal hashes do
    # not guarantee equal objects.
    return self.id == other.id and self.type == other.type

def __hash__(self) -> int:
    """Computes the hash of the resource.

    Returns:
        The hash of the string built from the resource ID and the value of
        the resource type.
    """
    hash_key = f"{str(self.id)}_{self.type.value}"
    return hash(hash_key)


class RunMetadataEntry(BaseModel):
"""Utility class to sort/list run metadata entries."""
Expand Down
146 changes: 144 additions & 2 deletions src/zenml/utils/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
# permissions and limitations under the License.
"""Utility functions to handle metadata for ZenML entities."""

from typing import Dict, List, Optional, Union, overload
from typing import Dict, List, Optional, Set, Union, overload
from uuid import UUID

from zenml.client import Client
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType
from zenml.models import RunMetadataResource
from zenml.models import (
PipelineRunIdentifier,
RunMetadataResource,
StepRunIdentifier,
VersionedIdentifier,
)
from zenml.steps.step_context import get_step_context

logger = get_logger(__name__)
Expand Down Expand Up @@ -366,3 +371,140 @@ def log_metadata(
resources=resources,
publisher_step_id=publisher_step_id,
)


def bulk_log_metadata(
    metadata: Dict[str, MetadataType],
    pipeline_runs: list[PipelineRunIdentifier] | None = None,
    step_runs: list[StepRunIdentifier] | None = None,
    artifact_versions: list[VersionedIdentifier] | None = None,
    model_versions: list[VersionedIdentifier] | None = None,
    infer_models: bool = False,
    infer_artifacts: bool = False,
) -> None:
    """Logs metadata for multiple entities in a single invocation.

    Identifiers that do not carry an ID are resolved through the client.
    The passed identifier objects are not modified.

    Args:
        metadata: The metadata to log.
        pipeline_runs: A list of pipeline runs to log metadata for.
        step_runs: A list of step runs to log metadata for.
        artifact_versions: A list of artifact versions to log metadata for.
        model_versions: A list of model versions to log metadata for.
        infer_models: Flag - when enabled infer model to log metadata for
            from step context.
        infer_artifacts: Flag - when enabled infer artifact to log metadata
            for from step context.

    Raises:
        ValueError: If options are not passed correctly (infer options with
            explicit declarations), if invocation with `infer` options is
            done outside of a step context, or if `infer_models` is used
            and the step context has no model version.
    """
    if infer_models and model_versions:
        raise ValueError(
            "You can either specify model versions or use the infer option."
        )

    if infer_artifacts and artifact_versions:
        raise ValueError(
            "You can either specify artifact versions or use the infer option."
        )

    # The step context is only needed (and only required to exist) when one
    # of the infer options is enabled.
    step_context = None
    if infer_models or infer_artifacts:
        try:
            step_context = get_step_context()
        except RuntimeError as e:
            raise ValueError(
                "Infer options can be used only within a step function code."
            ) from e

    client = Client()

    resources: Set[RunMetadataResource] = set()

    # resolve pipeline runs and add metadata resources

    for pipeline_run in pipeline_runs or []:
        pipeline_run_id = pipeline_run.id
        if not pipeline_run_id:
            pipeline_run_id = client.get_pipeline_run(
                name_id_or_prefix=pipeline_run.value
            ).id
        resources.add(
            RunMetadataResource(
                id=pipeline_run_id, type=MetadataResourceTypes.PIPELINE_RUN
            )
        )

    # resolve step runs and add metadata resources

    for step_run in step_runs or []:
        step_run_id = step_run.id
        if not step_run_id:
            # A step identified by name is looked up among the steps of the
            # pipeline run it belongs to.
            parent_run = client.get_pipeline_run(
                name_id_or_prefix=step_run.pipeline.value
            )
            step_run_id = parent_run.steps[step_run.name].id
        resources.add(
            RunMetadataResource(
                id=step_run_id, type=MetadataResourceTypes.STEP_RUN
            )
        )

    # resolve artifacts and add metadata resources

    for artifact_version in artifact_versions or []:
        artifact_version_id = artifact_version.id
        if not artifact_version_id:
            artifact_version_id = client.get_artifact_version(
                name_id_or_prefix=artifact_version.name,
                version=artifact_version.version,
            ).id
        resources.add(
            RunMetadataResource(
                id=artifact_version_id,
                type=MetadataResourceTypes.ARTIFACT_VERSION,
            )
        )

    # resolve models and add metadata resources

    for model_version in model_versions or []:
        model_version_id = model_version.id
        if not model_version_id:
            model_version_id = client.get_model_version(
                model_name_or_id=model_version.name,
                model_version_name_or_number_or_id=model_version.version,
            ).id
        resources.add(
            RunMetadataResource(
                id=model_version_id, type=MetadataResourceTypes.MODEL_VERSION
            )
        )

    # step_context is set exactly when at least one infer option is enabled
    if step_context is not None:
        # infer models - resolve from step context
        # Only touch the context's model version when inference was
        # actually requested.

        if infer_models:
            inferred_model_version = step_context.model_version
            if not inferred_model_version:
                raise ValueError(
                    "The step context does not feature any model versions."
                )
            resources.add(
                RunMetadataResource(
                    id=inferred_model_version.id,
                    type=MetadataResourceTypes.MODEL_VERSION,
                )
            )

        # infer artifacts - resolve from step context

        if infer_artifacts:
            # Output artifacts get their metadata attached through the step
            # context rather than through the resource list.
            for artifact_name in list(step_context._outputs):
                step_context.add_output_metadata(
                    metadata=metadata, output_name=artifact_name
                )

    # Skip the request when there is nothing to attach the metadata to,
    # e.g. when only `infer_artifacts` was used.
    if resources:
        client.create_run_metadata(
            metadata=metadata,
            resources=list(resources),
            publisher_step_id=None,
        )
Loading
Loading