Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ class VertexAIWorkerVariables(BaseVariables):
"within the provided ip ranges. Otherwise, the job will be deployed to "
"any ip ranges under the provided VPC network.",
)
scheduling: Optional[dict[str, Any]] = Field(
default=None,
title="Scheduling Options",
description=(
"A dictionary with scheduling options for a CustomJob, "
"these are parameters related to queuing, and scheduling custom jobs. "
"If unspecified default scheduling options are used. "
"The 'maximum_run_time_hours' variable will take precedance over the "
"'scheduling.timeout' field for backward compatibility."
),
)
service_account_name: Optional[str] = Field(
default=None,
title="Service Account Name",
Expand Down Expand Up @@ -235,6 +246,7 @@ class VertexAIWorkerJobConfiguration(BaseJobConfiguration):
"network": "{{ network }}",
"reserved_ip_ranges": "{{ reserved_ip_ranges }}",
"maximum_run_time_hours": "{{ maximum_run_time_hours }}",
"scheduling": "{{ scheduling }}",
"worker_pool_specs": [
{
"replica_count": 1,
Expand Down Expand Up @@ -471,12 +483,29 @@ def _build_job_spec(
for spec in configuration.job_spec.pop("worker_pool_specs", [])
]

timeout = Duration().FromTimedelta(
td=datetime.timedelta(
hours=configuration.job_spec["maximum_run_time_hours"]
scheduling = Scheduling()

if "scheduling" in configuration.job_spec:
scheduling_params = configuration.job_spec.pop("scheduling")
for key, value in scheduling_params.items():
# Handle if Strategy is passed as an Enum or Str
if key == "strategy":
if isinstance(value, Scheduling.Strategy):
setattr(scheduling, key, value)
else:
setattr(scheduling, key, Scheduling.Strategy[value])
else:
setattr(scheduling, key, value)

# Override "timeout" in Scheduling object if "maximum_run_time_hours" is specified
if "maximum_run_time_hours" in configuration.job_spec:
timeout = Duration()
timeout.FromTimedelta(
td=datetime.timedelta(
hours=configuration.job_spec["maximum_run_time_hours"]
)
)
)
scheduling = Scheduling(timeout=timeout)
scheduling.timeout = timeout

if "service_account_name" in configuration.job_spec:
service_account_name = configuration.job_spec.pop("service_account_name")
Expand Down
43 changes: 43 additions & 0 deletions src/integrations/prefect-gcp/tests/test_vertex_worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import datetime
import uuid
from unittest.mock import MagicMock

import pydantic
import pytest
from google.cloud.aiplatform_v1.types.custom_job import Scheduling
from google.cloud.aiplatform_v1.types.job_state import JobState
from google.protobuf.duration_pb2 import Duration
from prefect_gcp.workers.vertex import (
VertexAIWorker,
VertexAIWorkerJobConfiguration,
Expand Down Expand Up @@ -40,6 +43,10 @@ def job_config(service_account_info, gcp_credentials):
},
}
],
"scheduling": {
"strategy": "FLEX_START",
"max_wait_duration": "1800s",
},
},
)

Expand Down Expand Up @@ -156,6 +163,42 @@ async def test_successful_worker_run(self, flow_run, job_config):
status_code=0, identifier="mock_display_name"
)

async def test_params_worker_run(self, flow_run, job_config):
async with VertexAIWorker("test-pool") as worker:
# Initialize scheduling parameters
maximum_run_time_hours = job_config.job_spec["maximum_run_time_hours"]
max_wait_duration = job_config.job_spec["scheduling"]["max_wait_duration"]
timeout = Duration()
timeout.FromTimedelta(td=datetime.timedelta(hours=maximum_run_time_hours))
scheduling = Scheduling(
timeout=timeout, max_wait_duration=max_wait_duration
)

job_config.prepare_for_flow_run(flow_run, None, None)
result = await worker.run(flow_run=flow_run, configuration=job_config)

custom_job_spec = job_config.credentials.job_service_async_client.create_custom_job.call_args[
1
]["custom_job"].job_spec

# Assert scheduling parameters
assert custom_job_spec.scheduling.timeout == scheduling.timeout
assert (
custom_job_spec.scheduling.strategy == Scheduling.Strategy["FLEX_START"]
)
assert (
custom_job_spec.scheduling.max_wait_duration
== scheduling.max_wait_duration
)

assert (
job_config.credentials.job_service_async_client.get_custom_job.call_count
== 1
)
assert result == VertexAIWorkerResult(
status_code=0, identifier="mock_display_name"
)

async def test_failed_worker_run(self, flow_run, job_config):
job_config.prepare_for_flow_run(flow_run, None, None)
error_msg = "something went kablooey"
Expand Down