Skip to content

Commit 7cb9ea9

Browse files
feat: add Switch transpiler installer for Lakebridge integration
Implement SwitchInstaller to integrate Switch transpiler with Lakebridge: - Install Switch package to local virtual environment and deploy to workspace - Create and manage Databricks job for Switch transpilation - Configure Switch resources (catalog, schema, volume) interactively - Support job-level parameters with JobParameterDefinition for flexibility - Handle installation state and job lifecycle management - Add comprehensive test suite covering installation, job management, and configuration
1 parent d0c63c3 commit 7cb9ea9

File tree

2 files changed

+74
-90
lines changed

2 files changed

+74
-90
lines changed

src/databricks/labs/lakebridge/transpiler/installers.py

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository
2727
from databricks.sdk import WorkspaceClient
2828
from databricks.sdk.errors import InvalidParameterValue, NotFound
29-
from databricks.sdk.service.jobs import JobSettings, NotebookTask, Source, Task
29+
from databricks.sdk.service.jobs import JobParameterDefinition, JobSettings, NotebookTask, Source, Task
3030

3131
logger = logging.getLogger(__name__)
3232

@@ -643,7 +643,6 @@ def _deploy_workspace(self, switch_package_dir: Path) -> None:
643643
"""Deploy Switch package to workspace from site-packages."""
644644
try:
645645
logger.info("Deploying Switch package to workspace...")
646-
647646
remote_path = f"{self._TRANSPILER_ID}/databricks"
648647
self._upload_directory(switch_package_dir, remote_path)
649648
logger.info("Switch workspace deployment completed")
@@ -673,54 +672,47 @@ def _upload_directory(self, local_path: Path, remote_prefix: str) -> None:
673672
def _setup_job(self) -> None:
674673
"""Create Switch job if not exists."""
675674
install_state = InstallState.from_installation(self._installation)
676-
677-
if self._has_valid_job(install_state):
678-
logger.info("Switch job already exists")
679-
return
680-
681-
logger.info("Creating Switch transpiler job...")
675+
existing_job_id = self._get_existing_job_id(install_state)
676+
logger.info("Setting up Switch job in workspace...")
682677
try:
683-
job_id = self._create_or_update_switch_job(install_state)
678+
job_id = self._create_or_update_switch_job(existing_job_id)
679+
install_state.jobs[self._INSTALL_STATE_KEY] = job_id
684680
install_state.save()
685-
job_url = f"{self._workspace_client.config.host}#job/{job_id}"
681+
job_url = f"{self._workspace_client.config.host}/jobs/{job_id}"
686682
logger.info(f"Switch job created/updated: {job_url}")
687683
except (RuntimeError, ValueError, InvalidParameterValue) as e:
688-
logger.error(f"Failed to create Switch job: {e}")
684+
logger.error(f"Failed to create/update Switch job: {e}")
689685

690-
def _has_valid_job(self, install_state: InstallState) -> bool:
691-
"""Check if Switch job exists and is valid in workspace."""
686+
def _get_existing_job_id(self, install_state: InstallState) -> str | None:
687+
"""Check if Switch job already exists in workspace and return its job_id."""
692688
if self._INSTALL_STATE_KEY not in install_state.jobs:
693-
return False
689+
return None
694690
try:
695-
job_id = int(install_state.jobs[self._INSTALL_STATE_KEY])
696-
self._workspace_client.jobs.get(job_id)
697-
return True
691+
job_id = install_state.jobs[self._INSTALL_STATE_KEY]
692+
self._workspace_client.jobs.get(int(job_id))
693+
return job_id
698694
except (InvalidParameterValue, NotFound, ValueError):
699-
return False
695+
return None
700696

701-
def _create_or_update_switch_job(self, install_state: InstallState) -> str:
697+
def _create_or_update_switch_job(self, job_id: str | None) -> str:
702698
"""Create or update Switch job"""
703-
# Check for existing job and update
704-
if self._INSTALL_STATE_KEY in install_state.jobs:
699+
job_settings = self._get_switch_job_settings()
700+
701+
# Try to update existing job
702+
if job_id:
705703
try:
706-
job_id = int(install_state.jobs[self._INSTALL_STATE_KEY])
707704
logger.info(f"Updating Switch job: {job_id}")
708-
job_settings = self._get_switch_job_settings()
709-
self._workspace_client.jobs.reset(job_id, JobSettings(**job_settings))
710-
return str(job_id)
711-
except InvalidParameterValue:
712-
del install_state.jobs[self._INSTALL_STATE_KEY]
705+
self._workspace_client.jobs.reset(int(job_id), JobSettings(**job_settings))
706+
return job_id
707+
except (ValueError, InvalidParameterValue):
713708
logger.warning("Previous Switch job not found, creating new one")
714709

715710
# Create new job
716-
logger.info("Creating new Switch job configuration")
717-
job_settings = self._get_switch_job_settings()
711+
logger.info("Creating new Switch job")
718712
new_job = self._workspace_client.jobs.create(**job_settings)
719-
assert new_job.job_id is not None
720-
721-
# Save to InstallState
722-
install_state.jobs[self._INSTALL_STATE_KEY] = str(new_job.job_id)
723-
return str(new_job.job_id)
713+
new_job_id = str(new_job.job_id)
714+
assert new_job_id is not None
715+
return new_job_id
724716

725717
def _get_switch_job_settings(self) -> dict:
726718
"""Build job settings for Switch transpiler using serverless compute"""
@@ -736,25 +728,21 @@ def _get_switch_job_settings(self) -> dict:
736728
task_key="run_transpilation",
737729
notebook_task=NotebookTask(
738730
notebook_path=notebook_path,
739-
base_parameters=self._get_switch_parameters_from_config(),
740731
source=Source.WORKSPACE,
741732
),
742733
disable_auto_optimization=True, # To disable retries on failure
743734
)
744735

745736
return {
746737
"name": job_name,
747-
"tags": {"created_by": user_name, "version": f"v{version}"},
738+
"tags": {"created_by": user_name, "switch_version": f"v{version}"},
748739
"tasks": [task],
740+
"parameters": self._get_switch_job_parameters(),
749741
"max_concurrent_runs": 100, # Allow simultaneous transpilations
750742
}
751743

752-
def _get_switch_parameters_from_config(self) -> dict:
753-
"""Extract Switch parameters from installed config.yml.
754-
755-
Raises:
756-
ValueError: If Switch config.yml is not found (indicates incomplete installation)
757-
"""
744+
def _get_switch_job_parameters(self) -> list[JobParameterDefinition]:
745+
"""Build job-level parameter definitions from installed config.yml."""
758746
configs = self._transpiler_repository.all_transpiler_configs()
759747
config = configs.get(self._TRANSPILER_ID)
760748

@@ -786,7 +774,7 @@ def _get_switch_parameters_from_config(self) -> dict:
786774

787775
parameters[flag] = default
788776

789-
return parameters
777+
return [JobParameterDefinition(name=key, default=value) for key, value in parameters.items()]
790778

791779
def _configure_resources(self) -> None:
792780
"""Configure Switch resources (catalog, schema, volume) if not configured."""

tests/unit/transpiler/test_installers.py

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,14 @@ def test_java_version_parse_missing() -> None:
121121
class FriendOfSwitchInstaller(SwitchInstaller):
122122
"""A friend class to access protected methods for testing purposes."""
123123

124-
def has_valid_job(self, install_state: Any) -> bool:
125-
return self._has_valid_job(install_state)
124+
def get_existing_job_id(self, install_state: Any) -> str | None:
125+
return self._get_existing_job_id(install_state)
126126

127-
def create_or_update_switch_job(self, install_state: Any) -> str:
128-
return self._create_or_update_switch_job(install_state)
127+
def create_or_update_switch_job(self, job_id: str | None) -> str:
128+
return self._create_or_update_switch_job(job_id)
129129

130-
def get_switch_parameters_from_config(self) -> dict:
131-
return self._get_switch_parameters_from_config()
130+
def get_switch_job_parameters(self) -> list:
131+
return self._get_switch_job_parameters()
132132

133133
def prompt_for_switch_resources(self) -> tuple[str, str, str]:
134134
return self._prompt_for_switch_resources()
@@ -343,20 +343,20 @@ def test_get_configured_resources(
343343
@pytest.mark.parametrize(
344344
("jobs_state", "get_side_effect", "expected"),
345345
(
346-
pytest.param({}, None, False, id="no_job_in_state"),
347-
pytest.param({"Switch": "12345"}, None, True, id="valid_job_exists"),
348-
pytest.param({"Switch": "99999"}, NotFound("Job not found"), False, id="job_not_found"),
349-
pytest.param({"Switch": "invalid"}, ValueError("Invalid job ID"), False, id="invalid_job_id"),
346+
pytest.param({}, None, None, id="no_job_in_state"),
347+
pytest.param({"Switch": "12345"}, None, "12345", id="valid_job_exists"),
348+
pytest.param({"Switch": "99999"}, NotFound("Job not found"), None, id="job_not_found"),
349+
pytest.param({"Switch": "invalid"}, ValueError("Invalid job ID"), None, id="invalid_job_id"),
350350
),
351351
)
352-
def test_has_valid_job(
352+
def test_get_existing_job_id(
353353
self,
354354
jobs_state: dict,
355355
get_side_effect: Exception | None,
356-
expected: bool,
356+
expected: str | None,
357357
tmp_path: Path,
358358
) -> None:
359-
"""Test _has_valid_job checks job validity in workspace."""
359+
"""Test _get_existing_job_id returns job_id if valid, None otherwise."""
360360
install_state = Mock()
361361
install_state.jobs = jobs_state
362362

@@ -367,19 +367,19 @@ def test_has_valid_job(
367367
# Use friend class to access protected method
368368
repository = TranspilerRepository(tmp_path)
369369
friend_installer = FriendOfSwitchInstaller(repository, mock_ws, Mock())
370-
result = friend_installer.has_valid_job(install_state)
370+
result = friend_installer.get_existing_job_id(install_state)
371371

372372
assert result == expected
373373
if jobs_state and "Switch" in jobs_state and not get_side_effect:
374374
mock_ws.jobs.get.assert_called_once_with(int(jobs_state["Switch"]))
375375

376376
@pytest.mark.parametrize(
377-
("initial_jobs", "reset_side_effect", "expected_job_id", "expect_create", "expect_reset"),
377+
("initial_job_id", "reset_side_effect", "expected_job_id", "expect_create", "expect_reset"),
378378
(
379-
pytest.param({}, None, "12345", True, False, id="new_job_creation"),
380-
pytest.param({"Switch": "67890"}, None, "67890", False, True, id="existing_job_update"),
379+
pytest.param(None, None, "12345", True, False, id="new_job_creation"),
380+
pytest.param("67890", None, "67890", False, True, id="existing_job_update"),
381381
pytest.param(
382-
{"Switch": "99999"},
382+
"99999",
383383
InvalidParameterValue("Job not found"),
384384
"12345",
385385
True,
@@ -389,26 +389,18 @@ def test_has_valid_job(
389389
),
390390
)
391391
@patch.object(SwitchInstaller, "_get_switch_job_settings")
392-
@patch("databricks.labs.lakebridge.transpiler.installers.InstallState")
393392
def test_job_creation(
394393
self,
395-
mock_install_state_class: Mock,
396394
mock_get_settings: Mock,
397-
initial_jobs: dict,
395+
initial_job_id: str | None,
398396
reset_side_effect: Exception | None,
399397
expected_job_id: str,
400398
expect_create: bool,
401399
expect_reset: bool,
402-
installer: SwitchInstaller,
403400
tmp_path: Path,
404401
) -> None:
405402
"""Test Switch job creation and update scenarios."""
406403
# Setup
407-
mock_install_state = Mock()
408-
mock_install_state.jobs = dict(initial_jobs)
409-
mock_install_state.save = Mock()
410-
mock_install_state_class.from_installation.return_value = mock_install_state
411-
412404
mock_get_settings.return_value = {"name": "test_job"}
413405

414406
mock_job = Mock()
@@ -428,14 +420,15 @@ def test_job_creation(
428420
test_repository = TranspilerRepository(tmp_path)
429421
mock_installation = Mock()
430422
friend_installer = FriendOfSwitchInstaller(test_repository, mock_ws, mock_installation)
431-
result = friend_installer.create_or_update_switch_job(mock_install_state)
423+
result = friend_installer.create_or_update_switch_job(initial_job_id)
432424

433425
# Assert
434426
assert result == expected_job_id
435427

436428
if expect_reset:
437429
if reset_side_effect:
438-
mock_jobs.reset.assert_called_once_with(int(initial_jobs["Switch"]), JobSettings(name="test_job"))
430+
assert initial_job_id is not None
431+
mock_jobs.reset.assert_called_once_with(int(initial_job_id), JobSettings(name="test_job"))
439432
else:
440433
mock_jobs.reset.assert_called_once_with(int(expected_job_id), JobSettings(name="test_job"))
441434
else:
@@ -446,9 +439,7 @@ def test_job_creation(
446439
else:
447440
mock_jobs.create.assert_not_called()
448441

449-
assert mock_install_state.jobs["Switch"] == expected_job_id
450-
451-
def test_get_switch_parameters_handles_various_default_values(
442+
def test_get_switch_job_parameters_handles_various_default_values(
452443
self, installer: SwitchInstaller, tmp_path: Path
453444
) -> None:
454445
"""Test that different default values in config are correctly converted."""
@@ -469,20 +460,25 @@ def test_get_switch_parameters_handles_various_default_values(
469460
friend_installer = FriendOfSwitchInstaller(test_repository, mock_ws, mock_installation)
470461

471462
with patch.object(test_repository, "all_transpiler_configs", return_value={"switch": mock_config}):
472-
params = friend_installer.get_switch_parameters_from_config()
473-
474-
assert "input_dir" in params
475-
assert "output_dir" in params
476-
assert "result_catalog" in params
477-
assert "result_schema" in params
478-
assert "builtin_prompt" in params
479-
assert params["flag1"] == ""
480-
assert params["flag2"] == "123"
481-
assert params["flag3"] == ""
482-
assert params["flag4"] == "value"
483-
assert params["flag5"] == "3.14"
484-
485-
def test_get_switch_parameters_raises_when_config_missing(self, installer: SwitchInstaller, tmp_path: Path) -> None:
463+
params = friend_installer.get_switch_job_parameters()
464+
465+
# Convert list to dict for easier testing
466+
params_dict = {p.name: p.default for p in params}
467+
468+
assert "input_dir" in params_dict
469+
assert "output_dir" in params_dict
470+
assert "result_catalog" in params_dict
471+
assert "result_schema" in params_dict
472+
assert "builtin_prompt" in params_dict
473+
assert params_dict["flag1"] == ""
474+
assert params_dict["flag2"] == "123"
475+
assert params_dict["flag3"] == ""
476+
assert params_dict["flag4"] == "value"
477+
assert params_dict["flag5"] == "3.14"
478+
479+
def test_get_switch_job_parameters_raises_when_config_missing(
480+
self, installer: SwitchInstaller, tmp_path: Path
481+
) -> None:
486482
"""Test that ValueError is raised when Switch config is not found."""
487483
test_repository = TranspilerRepository(tmp_path)
488484
mock_ws = Mock()
@@ -491,7 +487,7 @@ def test_get_switch_parameters_raises_when_config_missing(self, installer: Switc
491487

492488
with patch.object(test_repository, "all_transpiler_configs", return_value={}):
493489
with pytest.raises(ValueError, match="Switch config.yml not found"):
494-
friend_installer.get_switch_parameters_from_config()
490+
friend_installer.get_switch_job_parameters()
495491

496492
@patch("databricks.labs.lakebridge.transpiler.installers.ResourceConfigurator")
497493
@patch("databricks.labs.lakebridge.transpiler.installers.CatalogOperations")

0 commit comments

Comments
 (0)