Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integrations/langchain/src/databricks_langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MCPServer,
)
from databricks_langchain.store import AsyncDatabricksStore, DatabricksStore
from databricks_langchain.uc_volume_tool import UCVolumeTool
from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
from databricks_langchain.vectorstores import DatabricksVectorSearch

Expand All @@ -40,6 +41,7 @@
"DatabricksStore",
"DatabricksVectorSearch",
"GenieAgent",
"UCVolumeTool",
"VectorSearchRetrieverTool",
"UCFunctionToolkit",
"UnityCatalogTool",
Expand Down
57 changes: 57 additions & 0 deletions integrations/langchain/src/databricks_langchain/uc_volume_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Type

from databricks_ai_bridge.uc_volume_tool import (
UCVolumeToolInput,
UCVolumeToolMixin,
uc_volume_tool_trace,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator


class UCVolumeTool(BaseTool, UCVolumeToolMixin):
    """
    LangChain tool that reads the contents of a file stored in a Databricks
    Unity Catalog Volume.

    The heavy lifting (name generation, description defaults, file reading)
    lives in :class:`UCVolumeToolMixin`; this class only adapts it to the
    LangChain ``BaseTool`` interface so it can be bound to an LLM.

    Example:

    .. code-block:: python

        from databricks_langchain import UCVolumeTool, ChatDatabricks

        vol_tool = UCVolumeTool(
            volume_name="catalog.schema.my_documents",
            tool_name="document_reader",
            tool_description="Reads files from the company documents volume.",
        )

        # Try it out directly
        vol_tool.invoke("reports/q4_summary.txt")

        # Attach to a model
        llm = ChatDatabricks(endpoint="databricks-claude-sonnet-4-5")
        llm_with_tools = llm.bind_tools([vol_tool])
        llm_with_tools.invoke("Read the Q4 summary from reports/q4_summary.txt")
    """

    # BaseTool requires 'name' and 'description' fields; populated in _validate_tool_inputs()
    name: str = Field(default="", description="The name of the tool")
    description: str = Field(default="", description="The description of the tool")
    args_schema: Type[BaseModel] = UCVolumeToolInput

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        # Derive BaseTool metadata from the mixin's configuration.
        self.description = self.tool_description or self._get_default_tool_description()
        self.name = self._get_tool_name()

        # Fall back to ambient Databricks credentials when no client was supplied.
        if not self.workspace_client:
            from databricks.sdk import WorkspaceClient

            self.workspace_client = WorkspaceClient()
        return self

    @uc_volume_tool_trace
    def _run(self, file_path: str, **kwargs) -> str:
        # Delegate to the mixin; tracing decorator records a TOOL span.
        return self._read_file(file_path)
120 changes: 120 additions & 0 deletions integrations/langchain/tests/unit_tests/test_uc_volume_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import io
from typing import Any
from unittest.mock import MagicMock

import mlflow
import pytest
from databricks_ai_bridge.test_utils.uc_volume import ( # noqa: F401
SAMPLE_FILE_CONTENT,
VOLUME_NAME,
mock_workspace_client,
)
from langchain_core.tools import BaseTool
from mlflow.entities import SpanType

from databricks_langchain import UCVolumeTool


def init_volume_tool(
    volume_name: str = VOLUME_NAME,
    tool_name: str | None = None,
    tool_description: str | None = None,
    **kwargs: Any,
) -> UCVolumeTool:
    # Build the constructor arguments in one expression; explicit keywords
    # take precedence over anything passed through **kwargs.
    merged = {
        **kwargs,
        "volume_name": volume_name,
        "tool_name": tool_name,
        "tool_description": tool_description,
    }
    return UCVolumeTool(**merged)


class TestInit:
    """Construction-time behavior: BaseTool compliance and name/description defaults."""

    def test_init_is_base_tool(self):
        assert isinstance(init_volume_tool(), BaseTool)

    def test_init_with_custom_name_and_description(self):
        tool = init_volume_tool(tool_name="my_reader", tool_description="Reads docs")
        assert (tool.name, tool.description) == ("my_reader", "Reads docs")

    def test_init_default_name_from_volume_name(self):
        # Dots in the fully-qualified volume name are mapped to double underscores.
        expected = VOLUME_NAME.replace(".", "__")
        assert init_volume_tool().name == expected

    def test_init_default_description(self):
        description = init_volume_tool().description
        assert VOLUME_NAME in description
        assert "Reads files" in description


class TestInvoke:
    """Invocation paths: string input, dict input, binary files, empty paths."""

    def test_invoke_returns_file_content(self):
        assert init_volume_tool().invoke("reports/q4.txt") == SAMPLE_FILE_CONTENT

    def test_invoke_with_dict_input(self):
        payload: dict[str, Any] = {"file_path": "reports/q4.txt"}
        assert init_volume_tool().invoke(payload) == SAMPLE_FILE_CONTENT

    def test_invoke_binary_file_returns_error(self, mock_workspace_client):
        # Simulate a download whose bytes cannot be decoded as text.
        fake_response = MagicMock()
        fake_response.contents = io.BytesIO(b"\x80\x81\x82\x83")
        mock_workspace_client.files.download.return_value = fake_response

        assert "binary file" in init_volume_tool().invoke("image.png")

    def test_invoke_empty_path_returns_error(self):
        outcome = init_volume_tool().invoke({"file_path": ""})
        assert "Error" in outcome


class TestToolNameGeneration:
    """Default name derivation and validation of user-supplied tool names."""

    def test_default_tool_name(self):
        assert init_volume_tool(volume_name="cat.schema.vol").name == "cat__schema__vol"

    @pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"])
    def test_valid_tool_names(self, tool_name):
        tool = init_volume_tool(tool_name=tool_name)
        assert tool.tool_name == tool_name
        if tool_name is not None:
            # An explicit name should be used verbatim.
            assert tool.name == tool_name

    @pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"])
    def test_invalid_tool_names(self, tool_name):
        # Characters outside the allowed set must be rejected at construction.
        with pytest.raises(ValueError):
            init_volume_tool(tool_name=tool_name)


class TestArgsSchema:
    """The pydantic input schema must expose a documented 'file_path' field."""

    def test_args_schema_has_file_path(self):
        fields = init_volume_tool().args_schema.model_fields
        assert "file_path" in fields
        assert fields["file_path"].description is not None


class TestTracing:
    """_run() should emit exactly one MLflow TOOL span named after the tool."""

    def _assert_single_tool_span(self, span_name):
        # Look up the most recent trace and count matching TOOL spans.
        trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
        matches = trace.search_spans(name=span_name, span_type=SpanType.TOOL)
        assert len(matches) == 1

    def test_tracing_with_default_name(self):
        init_volume_tool()._run("reports/q4.txt")
        self._assert_single_tool_span(VOLUME_NAME)

    def test_tracing_with_custom_name(self):
        init_volume_tool(tool_name="my_reader")._run("reports/q4.txt")
        self._assert_single_tool_span("my_reader")
2 changes: 2 additions & 0 deletions integrations/openai/src/databricks_openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from unitycatalog.ai.openai.toolkit import UCFunctionToolkit

from databricks_openai.mcp_server_toolkit import McpServerToolkit, ToolInfo
from databricks_openai.uc_volume_tool import UCVolumeTool
from databricks_openai.utils.clients import AsyncDatabricksOpenAI, DatabricksOpenAI
from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool

# Expose all integrations to users under databricks-openai
__all__ = [
"UCVolumeTool",
"VectorSearchRetrieverTool",
"UCFunctionToolkit",
"DatabricksFunctionClient",
Expand Down
94 changes: 94 additions & 0 deletions integrations/openai/src/databricks_openai/uc_volume_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Any, Optional

from databricks_ai_bridge.uc_volume_tool import (
UCVolumeToolInput,
UCVolumeToolMixin,
uc_volume_tool_trace,
)
from openai import pydantic_function_tool
from openai.types.chat import ChatCompletionToolParam
from pydantic import Field, model_validator


class UCVolumeTool(UCVolumeToolMixin):
    """
    OpenAI-SDK-compatible tool that reads files from a Databricks Unity
    Catalog Volume.

    On construction, a ``ChatCompletionToolParam`` spec is generated from
    :class:`UCVolumeToolInput` and stored on ``self.tool`` so it can be passed
    straight to ``chat.completions.create(tools=[...])``. Follows the same
    pattern as VectorSearchRetrieverTool.

    Example:
        Step 1: Call model with UCVolumeTool defined

        .. code-block:: python

            vol_tool = UCVolumeTool(volume_name="catalog.schema.my_documents")
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Read the Q4 summary from reports/q4_summary.txt"},
            ]
            first_response = client.chat.completions.create(
                model="gpt-4o", messages=messages, tools=[vol_tool.tool]
            )

        Step 2: Execute function code – parse the model's response and handle function calls.

        .. code-block:: python

            tool_call = first_response.choices[0].message.tool_calls[0]
            args = json.loads(tool_call.function.arguments)
            result = vol_tool.execute(file_path=args["file_path"])

        Step 3: Supply model with results – so it can incorporate them into its final response.

        .. code-block:: python

            messages.append(first_response.choices[0].message)
            messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})
            second_response = client.chat.completions.create(
                model="gpt-4o", messages=messages, tools=[vol_tool.tool]
            )
    """

    tool: Optional[ChatCompletionToolParam] = Field(
        None, description="The tool input used in the OpenAI chat completion SDK"
    )

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        # Fall back to ambient Databricks credentials when no client was supplied.
        if not self.workspace_client:
            from databricks.sdk import WorkspaceClient

            self.workspace_client = WorkspaceClient()

        self.tool = pydantic_function_tool(
            UCVolumeToolInput,
            name=self._get_tool_name(),
            description=self.tool_description or self._get_default_tool_description(),
        )

        # pydantic_function_tool emits "strict" and "additionalProperties"
        # entries that some serving endpoints reject; strip them when present.
        function_spec = self.tool.get("function")
        if function_spec is not None:
            function_spec.pop("strict", None)
            parameters = function_spec.get("parameters")
            if parameters is not None:
                parameters.pop("additionalProperties", None)

        return self

    @uc_volume_tool_trace
    def execute(self, file_path: str, **kwargs: Any) -> str:
        """
        Read a file from the volume and return its contents.

        Args:
            file_path: The path to the file relative to the volume root.

        Returns:
            The file contents as a string.
        """
        return self._read_file(file_path)
Loading
Loading