Commit eb166cf

Add instance pool id to make_job to speed up testing
1 parent: 6ec8c0d

File tree

3 files changed: +18 -5 lines changed


src/databricks/labs/pytester/fixtures/compute.py

Lines changed: 7 additions & 1 deletion
@@ -180,6 +180,7 @@ def make_job(
     * [DEPRECATED: Use `path` instead] `notebook_path` (str, optional): The path to the notebook. If not provided, a random notebook will be created.
     * `content` (str | bytes, optional): The content of the notebook or file used in the job. If not provided, default content of `make_notebook` will be used.
     * `task_type` (type[NotebookTask] | type[SparkPythonTask], optional): The type of task. If not provides, `type[NotebookTask]` will be used.
+    * `instance_pool_id` (str, optional): The instance pool id to add to the job cluster. If not provided, no instance pool will be used.
     * `spark_conf` (dict, optional): The Spark configuration of the job. If not provided, Spark configuration is not explicitly set.
     * `libraries` (list, optional): The list of libraries to install on the job.
     * `tags` (list[str], optional): A list of job tags. If not provided, no additional tags will be set on the job.
@@ -202,6 +203,7 @@ def create(
         content: str | bytes | None = None,
         task_type: type[NotebookTask] | type[SparkPythonTask] = NotebookTask,
         spark_conf: dict[str, str] | None = None,
+        instance_pool_id: str | None = None,
         libraries: list[Library] | None = None,
         tags: dict[str, str] | None = None,
         tasks: list[Task] | None = None,
@@ -223,13 +225,17 @@ def create(
         tags = tags or {}
         tags["RemoveAfter"] = tags.get("RemoveAfter", watchdog_remove_after)
         if not tasks:
+            node_type_id = None
+            if instance_pool_id is None:
+                node_type_id = ws.clusters.select_node_type(local_disk=True, min_memory_gb=16)
             task = Task(
                 task_key=make_random(4),
                 description=make_random(4),
                 new_cluster=ClusterSpec(
                     num_workers=1,
-                    node_type_id=ws.clusters.select_node_type(local_disk=True, min_memory_gb=16),
+                    node_type_id=node_type_id,
                     spark_version=ws.clusters.select_spark_version(latest=True),
+                    instance_pool_id=instance_pool_id,
                     spark_conf=spark_conf,
                 ),
                 libraries=libraries,
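The new parameter lets the job cluster attach to an existing instance pool, so runs start on pre-warmed instances instead of provisioning fresh VMs. A minimal sketch of wiring the fixtures together, assuming `make_instance_pool` returns the SDK's create response (which exposes `instance_pool_id`); the integration tests below instead read a pre-existing pool id from the environment:

def test_job_reuses_pool(make_instance_pool, make_job) -> None:
    # Create a throwaway pool, then point the job cluster at it via the new argument.
    pool = make_instance_pool()
    job = make_job(instance_pool_id=pool.instance_pool_id)
    # The generated task's cluster spec should carry the pool id instead of a node type.
    assert job.settings.tasks[0].new_cluster.instance_pool_id == pool.instance_pool_id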

tests/integration/fixtures/test_compute.py

Lines changed: 4 additions & 4 deletions
@@ -22,16 +22,16 @@ def test_instance_pool(make_instance_pool):
     logger.info(f"created {make_instance_pool()}")
 
 
-def test_job(ws: WorkspaceClient, make_job) -> None:
-    job = make_job()
+def test_job(ws: WorkspaceClient, make_job, env_or_skip) -> None:
+    job = make_job(instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"))
     run = ws.jobs.run_now(job.job_id)
     ws.jobs.wait_get_run_job_terminated_or_skipped(run_id=run.run_id)
     run_state = ws.jobs.get_run(run_id=run.run_id).state
     assert run_state is not None and run_state.result_state == RunResultState.SUCCESS
 
 
-def test_job_with_spark_python_task(ws: WorkspaceClient, make_job) -> None:
-    job = make_job(task_type=SparkPythonTask)
+def test_job_with_spark_python_task(ws: WorkspaceClient, make_job, env_or_skip) -> None:
+    job = make_job(task_type=SparkPythonTask, instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"))
     run = ws.jobs.run_now(job.job_id)
     ws.jobs.wait_get_run_job_terminated_or_skipped(run_id=run.run_id)
     run_state = ws.jobs.get_run(run_id=run.run_id).state
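Both tests now take the pool id from the `TEST_INSTANCE_POOL_ID` environment variable through the `env_or_skip` fixture, so they only run where a pool is available. A rough, self-contained sketch of the assumed fixture semantics (return the value, or skip the test when it is unset), not the actual pytester implementation:

import os

import pytest

def env_or_skip_sketch(name: str) -> str:
    # Look up the environment variable; skip the calling test instead of failing when it is missing.
    value = os.environ.get(name)
    if not value:
        pytest.skip(f"Environment variable {name} is not set")
    return value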

tests/unit/fixtures/test_compute.py

Lines changed: 7 additions & 0 deletions
@@ -77,6 +77,13 @@ def test_make_job_with_spark_python_task() -> None:
     assert workspace_path.read_text() == "print(3)"
 
 
+def test_make_job_with_instance_pool_id() -> None:
+    _, job = call_stateful(make_job, instance_pool_id="test")
+    tasks = job.settings.tasks
+    assert len(tasks) == 1
+    assert tasks[0].new_cluster.instance_pool_id == "test"
+
+
 def test_make_job_with_spark_conf() -> None:
     _, job = call_stateful(make_job, spark_conf={"value": "test"})
     tasks = job.settings.tasks
