Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# orchestrator constants
ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator"
# Name of the environment variable that the DAG runner reads (through
# `handle_int_env_var`, with a fallback of 10) to set `max_workers` of its
# `startup_executor` thread pool.
# NOTE(review): at the call site the parsed value is handed to
# `ThreadPoolExecutor(max_workers=...)` as-is, and that constructor raises
# ValueError for values <= 0. Whether `handle_int_env_var` already rejects
# such values is not visible here — TODO confirm, or validate at the caller.
ENV_ZENML_DAG_RUNNER_WORKER_COUNT = "ZENML_DAG_RUNNER_WORKER_COUNT"

# deployer constants
DEPLOYER_DOCKER_IMAGE_KEY = "deployer"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

from pydantic import BaseModel

from zenml.constants import (
ENV_ZENML_DAG_RUNNER_WORKER_COUNT,
handle_int_env_var,
)
from zenml.logger import get_logger
from zenml.utils.enum_utils import StrEnum

Expand Down Expand Up @@ -135,8 +139,11 @@ def __init__(
self.interrupt_check_interval = interrupt_check_interval
self.max_parallelism = max_parallelism
self.shutdown_event = threading.Event()
worker_count = handle_int_env_var(
ENV_ZENML_DAG_RUNNER_WORKER_COUNT, 10
)
self.startup_executor = ThreadPoolExecutor(
max_workers=10, thread_name_prefix="DagRunner-Startup"
max_workers=worker_count, thread_name_prefix="DagRunner-Startup"
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import threading
import time
from contextlib import nullcontext
from typing import List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, cast
from uuid import UUID

from kubernetes import client as k8s_client
Expand Down Expand Up @@ -62,6 +62,7 @@
PipelineRunUpdate,
PipelineSnapshotResponse,
RunMetadataResource,
StepRunResponse,
)
from zenml.orchestrators import publish_utils
from zenml.orchestrators.step_run_utils import (
Expand Down Expand Up @@ -327,7 +328,7 @@ def main() -> None:
pipeline_run=pipeline_run,
stack=active_stack,
)
step_runs = {}
step_runs: Dict[str, StepRunResponse] = {}

base_labels = {
"project_id": kube_utils.sanitize_label(str(snapshot.project_id)),
Expand All @@ -346,7 +347,9 @@ def _cache_step_run_if_possible(step_name: str) -> bool:
step_name
)
try:
step_run_request_factory.populate_request(step_run_request)
step_run_request_factory.populate_request(
step_run_request, step_runs=step_runs
)
except Exception as e:
logger.error(
f"Failed to populate step run request for step {step_name}: {e}"
Expand Down Expand Up @@ -525,15 +528,22 @@ def start_step_job(node: Node) -> NodeStatus:
job_manifest=job_manifest,
)

Client().create_run_metadata(
metadata={"step_jobs": {step_name: job_name}},
resources=[
RunMetadataResource(
id=pipeline_run.id,
type=MetadataResourceTypes.PIPELINE_RUN,
)
],
)
try:
Client().create_run_metadata(
metadata={"step_jobs": {step_name: job_name}},
resources=[
RunMetadataResource(
id=pipeline_run.id,
type=MetadataResourceTypes.PIPELINE_RUN,
)
],
)
except Exception as e:
logger.warning(
"Failed to create run metadata for step `%s`: %s",
step_name,
str(e),
)

node.metadata["job_name"] = job_name

Expand Down
1 change: 1 addition & 0 deletions src/zenml/orchestrators/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def get_cached_step_run(cache_key: str) -> Optional["StepRunResponse"]:
status=ExecutionStatus.COMPLETED,
sort_by=f"{SorterOps.DESCENDING}:created",
size=1,
hydrate=True,
).items

if cache_candidates:
Expand Down
1 change: 1 addition & 0 deletions tests/unit/orchestrators/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def test_fetching_cached_step_run_queries_cache_candidates(
status=ExecutionStatus.COMPLETED,
sort_by=f"{SorterOps.DESCENDING}:created",
size=1,
hydrate=True,
)


Expand Down
Loading