diff --git a/integrations/dspy/pyproject.toml b/integrations/dspy/pyproject.toml index 406090e61..1c8c7afc0 100644 --- a/integrations/dspy/pyproject.toml +++ b/integrations/dspy/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.10" dependencies = [ "dspy>=2.6.27", "databricks-sdk>=0.58.0", + "mlflow>=3.0.0", ] [project.optional-dependencies] diff --git a/integrations/dspy/src/databricks_dspy/clients/databricks_lm.py b/integrations/dspy/src/databricks_dspy/clients/databricks_lm.py index 528b6398a..d3e3542d6 100644 --- a/integrations/dspy/src/databricks_dspy/clients/databricks_lm.py +++ b/integrations/dspy/src/databricks_dspy/clients/databricks_lm.py @@ -1,7 +1,11 @@ +import logging from typing import Optional import dspy from databricks.sdk import WorkspaceClient +from databricks.sdk.service.serving import PtEndpointCoreConfig, PtServedModel + +logger = logging.getLogger(__name__) class DatabricksLM(dspy.LM): @@ -9,8 +13,70 @@ def __init__( self, model: str, workspace_client: Optional[WorkspaceClient] = None, + create_pt_endpoint: bool = False, + pt_entity: Optional[PtServedModel] = None, **kwargs, ): + """Subclass of `dspy.LM` for compatibility with Databricks. + + Args: + model: The model to use. Must start with 'databricks/'. + workspace_client: The workspace client to use. If not provided, a new one will be + created with default credentials from the environment. + create_pt_endpoint: Whether to create a provisioned throughput endpoint to make LM + calls. + pt_entity: The entity to serve, only used when `create_pt_endpoint` is True. + + Example 1: Use a Databricks model with preconfigured workspace client. + + ```python + import dspy + import databricks_dspy + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + lm = databricks_dspy.DatabricksLM( + "databricks/databricks-llama-4-maverick", + workspace_client=w, + ) + dspy.configure(lm=lm) + + predict = dspy.Predict("q->a") + print(predict(q="why did a chicken cross the kitchen?")) + ``` + + Example 2: Create a provisioned throughput endpoint for a Databricks model. + + ```python + import dspy + import databricks_dspy + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.serving import PtServedModel + + w = WorkspaceClient() + entity = PtServedModel( + entity_name="system.ai.llama-4-maverick", + entity_version="1", + provisioned_model_units=50, + ) + lm = databricks_dspy.DatabricksLM( + "databricks/provisioned-llama-4-maverick", + workspace_client=w, + create_pt_endpoint=True, + pt_entity=entity, + ) + dspy.configure(lm=lm) + + predict = dspy.Predict("q->a") + print(predict(q="why did a chicken cross the kitchen?")) + ``` + """ + if not model.startswith("databricks/"): + raise ValueError( + "`model` must start with 'databricks/' when using `DatabricksLM`, " + "e.g. dspy.LM('databricks/databricks-llama-4-maverick')" + ) + super().__init__(model=model, **kwargs) if workspace_client: @@ -28,6 +94,37 @@ def __init__( "for how to set up the authentication." ) from e + self.create_pt_endpoint = create_pt_endpoint + self.pt_entity = pt_entity + + if create_pt_endpoint: + self.pt_endpoint = self._create_pt_endpoint() + + def _create_pt_endpoint(self): + # Create the provisioned throughput endpoint configuration + config = PtEndpointCoreConfig(served_entities=[self.pt_entity]) + + model_name_without_databricks_prefix = self.model[len("databricks/") :] + # Create the provisioned throughput endpoint + w = self.workspace_client + try: + return w.serving_endpoints.create_provisioned_throughput_endpoint_and_wait( + name=model_name_without_databricks_prefix, + config=config, + ) + except Exception as e: + raise RuntimeError( + f"Failed to create provisioned throughput endpoint: {e}\n\n" + "`create_pt_endpoint=True` is only supported in Databricks notebooks now." + ) from e + + def tear_down(self): + if not self.create_pt_endpoint: + logger.warning("`tear_down` is an no-op when `create_pt_endpoint` is False.") + return + + self.workspace_client.serving_endpoints.delete(self.pt_endpoint.name) + def forward(self, **kwargs): return super().forward( headers=self.workspace_client.config.authenticate(), diff --git a/integrations/dspy/tests/unit_tests/clients/test_databricks_lm.py b/integrations/dspy/tests/unit_tests/clients/test_databricks_lm.py index 9323e0f6e..5124d5836 100644 --- a/integrations/dspy/tests/unit_tests/clients/test_databricks_lm.py +++ b/integrations/dspy/tests/unit_tests/clients/test_databricks_lm.py @@ -10,7 +10,7 @@ def test_forward_invokes_authenticate(): mock_ws_client.config.authenticate.return_value = {"Authorization": "Bearer token"} mock_ws_client.config.host = "https://test-host" mock_ws_client.current_user.me.return_value = "some valid value" - lm = DatabricksLM(model="test-model", workspace_client=mock_ws_client) + lm = DatabricksLM(model="databricks/test-model", workspace_client=mock_ws_client) with patch("databricks_dspy.clients.databricks_lm.dspy.LM.forward") as mock_super_forward: # Call the LM (`DatabricksLM.__call__` will call `forward`) @@ -34,7 +34,7 @@ def test_valid_credentials(): mock_ws.current_user = mock_current_user MockWSClient.return_value = mock_ws - DatabricksLM(model="test-model") + DatabricksLM(model="databricks/test-model") mock_current_user.me.assert_called_once() @@ -48,5 +48,23 @@ def test_invalid_credentials_raise_error(): MockWSClient.return_value = mock_ws with pytest.raises(RuntimeError, match="Failed to validate databricks credentials"): - DatabricksLM(model="test-model") + DatabricksLM(model="databricks/test-model") mock_current_user.me.assert_called_once() + + +def test_create_pt_endpoint_failed(): + with patch("databricks_dspy.clients.databricks_lm.WorkspaceClient") as MockWSClient: # noqa: E501 + mock_ws = MagicMock() + # Simulate endpoint creation failure. + mock_ws.serving_endpoints.create_provisioned_throughput_endpoint_and_wait.side_effect = ( + Exception("PT endpoint creation failed") + ) # noqa: E501 + MockWSClient.return_value = mock_ws + + pt_entity = MagicMock() + with pytest.raises(RuntimeError) as e: + DatabricksLM( + model="databricks/test-model", create_pt_endpoint=True, pt_entity=pt_entity + ) + assert "Failed to create provisioned throughput endpoint" in str(e) + assert "`create_pt_endpoint=True` is only supported in Databricks notebooks now." in str(e)