Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integrations/dspy/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ requires-python = ">=3.10"
dependencies = [
"dspy>=2.6.27",
"databricks-sdk>=0.58.0",
"mlflow>=3.0.0",
]

[project.optional-dependencies]
Expand Down
97 changes: 97 additions & 0 deletions integrations/dspy/src/databricks_dspy/clients/databricks_lm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,82 @@
import logging
from typing import Optional

import dspy
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import PtEndpointCoreConfig, PtServedModel

logger = logging.getLogger(__name__)


class DatabricksLM(dspy.LM):
def __init__(
self,
model: str,
workspace_client: Optional[WorkspaceClient] = None,
create_pt_endpoint: bool = False,
pt_entity: Optional[PtServedModel] = None,
**kwargs,
):
"""Subclass of `dspy.LM` for compatibility with Databricks.

Args:
model: The model to use. Must start with 'databricks/'.
workspace_client: The workspace client to use. If not provided, a new one will be
created with default credentials from the environment.
create_pt_endpoint: Whether to create a provisioned throughput endpoint to make LM
calls.
pt_entity: The entity to serve, only used when `create_pt_endpoint` is True.

Example 1: Use a Databricks model with preconfigured workspace client.

```python
import dspy
import databricks_dspy
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
lm = databricks_dspy.DatabricksLM(
"databricks/databricks-llama-4-maverick",
workspace_client=w,
)
dspy.configure(lm=lm)

predict = dspy.Predict("q->a")
print(predict(q="why did a chicken cross the kitchen?"))
```

Example 2: Create a provisioned throughput endpoint for a Databricks model.

```python
import dspy
import databricks_dspy
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import PtServedModel

w = WorkspaceClient()
entity = PtServedModel(
entity_name="system.ai.llama-4-maverick",
entity_version="1",
provisioned_model_units=50,
)
lm = databricks_dspy.DatabricksLM(
"databricks/provisioned-llama-4-maverick",
workspace_client=w,
create_pt_endpoint=True,
pt_entity=entity,
)
dspy.configure(lm=lm)

predict = dspy.Predict("q->a")
print(predict(q="why did a chicken cross the kitchen?"))
```
"""
if not model.startswith("databricks/"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to also check if we're in a databricks notebook env and log a warning / exit if not?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes good call!

It's quite a hassle to validate the Databricks notebook environment though: https://github.com/mlflow/mlflow/blob/a0e03e1004989740f10b101bd91582fcda733749/mlflow/utils/databricks_utils.py#L184, which cross-references about 8 methods. So I am taking a workaround: using a try-except block to raise an error message that prompts users to use a Databricks notebook, which should be safer and more generic.

raise ValueError(
"`model` must start with 'databricks/' when using `DatabricksLM`, "
"e.g. dspy.LM('databricks/databricks-llama-4-maverick')"
)

super().__init__(model=model, **kwargs)

if workspace_client:
Expand All @@ -28,6 +94,37 @@ def __init__(
"for how to set up the authentication."
) from e

self.create_pt_endpoint = create_pt_endpoint
self.pt_entity = pt_entity

if create_pt_endpoint:
self.pt_endpoint = self._create_pt_endpoint()

def _create_pt_endpoint(self):
    """Create a provisioned throughput (PT) serving endpoint for ``self.model``.

    The endpoint serves exactly one entity (``self.pt_entity``) and is named
    after the model with the ``databricks/`` routing prefix stripped. The call
    blocks until the endpoint is ready.

    Returns:
        The created serving endpoint object returned by the Databricks SDK.

    Raises:
        ValueError: If ``pt_entity`` was not provided.
        RuntimeError: If endpoint creation fails, e.g. when not running inside
            a Databricks notebook environment.
    """
    # Fail fast with a clear message instead of letting a `None` entity
    # surface as an opaque API error from the serving endpoint call below.
    if self.pt_entity is None:
        raise ValueError(
            "`pt_entity` must be provided when `create_pt_endpoint` is True."
        )

    # Create the provisioned throughput endpoint configuration.
    config = PtEndpointCoreConfig(served_entities=[self.pt_entity])

    # The serving endpoint name must not include the 'databricks/' prefix,
    # which is only used for LM routing.
    model_name_without_databricks_prefix = self.model[len("databricks/") :]
    # Create the provisioned throughput endpoint.
    w = self.workspace_client
    try:
        # Blocks until the endpoint reaches a ready state.
        return w.serving_endpoints.create_provisioned_throughput_endpoint_and_wait(
            name=model_name_without_databricks_prefix,
            config=config,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to create provisioned throughput endpoint: {e}\n\n"
            "`create_pt_endpoint=True` is only supported in Databricks notebooks now."
        ) from e

def tear_down(self):
    """Delete the provisioned throughput endpoint created by this instance.

    This is a no-op (with a warning) when the instance did not create an
    endpoint, i.e. when ``create_pt_endpoint`` is False.
    """
    if not self.create_pt_endpoint:
        # Nothing to clean up: this instance never created an endpoint.
        # (Fixed grammar: "an no-op" -> "a no-op".)
        logger.warning("`tear_down` is a no-op when `create_pt_endpoint` is False.")
        return

    # `self.pt_endpoint` is set in __init__ whenever create_pt_endpoint is True.
    self.workspace_client.serving_endpoints.delete(self.pt_endpoint.name)

def forward(self, **kwargs):
return super().forward(
headers=self.workspace_client.config.authenticate(),
Expand Down
24 changes: 21 additions & 3 deletions integrations/dspy/tests/unit_tests/clients/test_databricks_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_forward_invokes_authenticate():
mock_ws_client.config.authenticate.return_value = {"Authorization": "Bearer token"}
mock_ws_client.config.host = "https://test-host"
mock_ws_client.current_user.me.return_value = "some valid value"
lm = DatabricksLM(model="test-model", workspace_client=mock_ws_client)
lm = DatabricksLM(model="databricks/test-model", workspace_client=mock_ws_client)

with patch("databricks_dspy.clients.databricks_lm.dspy.LM.forward") as mock_super_forward:
# Call the LM (`DatabricksLM.__call__` will call `forward`)
Expand All @@ -34,7 +34,7 @@ def test_valid_credentials():
mock_ws.current_user = mock_current_user
MockWSClient.return_value = mock_ws

DatabricksLM(model="test-model")
DatabricksLM(model="databricks/test-model")
mock_current_user.me.assert_called_once()


Expand All @@ -48,5 +48,23 @@ def test_invalid_credentials_raise_error():
MockWSClient.return_value = mock_ws

with pytest.raises(RuntimeError, match="Failed to validate databricks credentials"):
DatabricksLM(model="test-model")
DatabricksLM(model="databricks/test-model")
mock_current_user.me.assert_called_once()


def test_create_pt_endpoint_failed():
    """Endpoint-creation failures must surface as a RuntimeError with guidance."""
    with patch("databricks_dspy.clients.databricks_lm.WorkspaceClient") as MockWSClient:
        mock_ws = MagicMock()
        # Simulate endpoint creation failure.
        mock_ws.serving_endpoints.create_provisioned_throughput_endpoint_and_wait.side_effect = (
            Exception("PT endpoint creation failed")
        )
        MockWSClient.return_value = mock_ws

        pt_entity = MagicMock()
        with pytest.raises(RuntimeError) as e:
            DatabricksLM(
                model="databricks/test-model", create_pt_endpoint=True, pt_entity=pt_entity
            )
        # Assert on the raised exception itself: `str(e)` stringifies the
        # pytest ExceptionInfo wrapper, not the exception message.
        assert "Failed to create provisioned throughput endpoint" in str(e.value)
        assert (
            "`create_pt_endpoint=True` is only supported in Databricks notebooks now."
            in str(e.value)
        )