-
Notifications
You must be signed in to change notification settings - Fork 51
Add UCVolumeTool for reading files from UC Volumes #323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
integrations/langchain/src/databricks_langchain/uc_volume_tool.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from typing import Type | ||
|
|
||
| from databricks_ai_bridge.uc_volume_tool import ( | ||
| UCVolumeToolInput, | ||
| UCVolumeToolMixin, | ||
| uc_volume_tool_trace, | ||
| ) | ||
| from langchain_core.tools import BaseTool | ||
| from pydantic import BaseModel, Field, model_validator | ||
|
|
||
|
|
||
class UCVolumeTool(BaseTool, UCVolumeToolMixin):
    """
    LangChain tool that reads files from a Databricks Unity Catalog Volume.

    Combines LangChain's ``BaseTool`` interface with ``UCVolumeToolMixin`` to offer
    a convenient way of building a file reading tool for agents.

    Example:

        .. code-block:: python

            from databricks_langchain import UCVolumeTool, ChatDatabricks

            vol_tool = UCVolumeTool(
                volume_name="catalog.schema.my_documents",
                tool_name="document_reader",
                tool_description="Reads files from the company documents volume.",
            )

            # Test locally
            vol_tool.invoke("reports/q4_summary.txt")

            # Bind to LLM
            llm = ChatDatabricks(endpoint="databricks-claude-sonnet-4-5")
            llm_with_tools = llm.bind_tools([vol_tool])
            llm_with_tools.invoke("Read the Q4 summary from reports/q4_summary.txt")
    """

    # BaseTool requires 'name' and 'description'; the empty defaults are placeholders
    # that _validate_tool_inputs() overwrites once the model has been validated.
    name: str = Field(default="", description="The name of the tool")
    description: str = Field(default="", description="The description of the tool")
    args_schema: Type[BaseModel] = UCVolumeToolInput

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        """Fill in the tool name, description and workspace client after validation."""
        self.name = self._get_tool_name()

        # A user-supplied description wins; otherwise fall back to the mixin default.
        custom_description = self.tool_description
        if custom_description:
            self.description = custom_description
        else:
            self.description = self._get_default_tool_description()

        if self.workspace_client:
            return self

        # Imported lazily: only needed when the caller did not supply a client.
        from databricks.sdk import WorkspaceClient

        self.workspace_client = WorkspaceClient()
        return self

    @uc_volume_tool_trace
    def _run(self, file_path: str, **kwargs) -> str:
        """Return whatever the mixin's ``_read_file`` yields for ``file_path``."""
        return self._read_file(file_path)
120 changes: 120 additions & 0 deletions
120
integrations/langchain/tests/unit_tests/test_uc_volume_tool.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| import io | ||
| from typing import Any | ||
| from unittest.mock import MagicMock | ||
|
|
||
| import mlflow | ||
| import pytest | ||
| from databricks_ai_bridge.test_utils.uc_volume import ( # noqa: F401 | ||
| SAMPLE_FILE_CONTENT, | ||
| VOLUME_NAME, | ||
| mock_workspace_client, | ||
| ) | ||
| from langchain_core.tools import BaseTool | ||
| from mlflow.entities import SpanType | ||
|
|
||
| from databricks_langchain import UCVolumeTool | ||
|
|
||
|
|
||
def init_volume_tool(
    volume_name: str = VOLUME_NAME,
    tool_name: str | None = None,
    tool_description: str | None = None,
    **kwargs: Any,
) -> UCVolumeTool:
    """Build a ``UCVolumeTool`` for tests.

    Args:
        volume_name: Volume name handed to the tool; defaults to ``VOLUME_NAME``.
        tool_name: Optional custom tool name.
        tool_description: Optional custom tool description.
        **kwargs: Any further keyword arguments forwarded to ``UCVolumeTool``.

    Returns:
        The constructed ``UCVolumeTool`` instance.
    """
    # Extra kwargs go first so the keyword order matches updating kwargs in place.
    return UCVolumeTool(
        **kwargs,
        volume_name=volume_name,
        tool_name=tool_name,
        tool_description=tool_description,
    )
|
|
||
|
|
||
class TestInit:
    """Construction checks: base class, custom and default name/description."""

    def test_init_is_base_tool(self):
        assert isinstance(init_volume_tool(), BaseTool)

    def test_init_with_custom_name_and_description(self):
        volume_tool = init_volume_tool(tool_name="my_reader", tool_description="Reads docs")
        assert (volume_tool.name, volume_tool.description) == ("my_reader", "Reads docs")

    def test_init_default_name_from_volume_name(self):
        # The default name is the volume name with dots replaced by double underscores.
        expected_name = VOLUME_NAME.replace(".", "__")
        assert init_volume_tool().name == expected_name

    def test_init_default_description(self):
        default_description = init_volume_tool().description
        assert VOLUME_NAME in default_description
        assert "Reads files" in default_description
|
|
||
|
|
||
class TestInvoke:
    """Invocation checks for string input, dict input and error responses."""

    def test_invoke_returns_file_content(self):
        content = init_volume_tool().invoke("reports/q4.txt")
        assert content == SAMPLE_FILE_CONTENT

    def test_invoke_with_dict_input(self):
        volume_tool = init_volume_tool()
        payload: dict[str, Any] = {"file_path": "reports/q4.txt"}
        assert volume_tool.invoke(payload) == SAMPLE_FILE_CONTENT

    def test_invoke_binary_file_returns_error(self, mock_workspace_client):
        # These bytes are not valid UTF-8, standing in for a binary download.
        binary_response = MagicMock()
        binary_response.contents = io.BytesIO(b"\x80\x81\x82\x83")
        mock_workspace_client.files.download.return_value = binary_response

        outcome = init_volume_tool().invoke("image.png")
        assert "binary file" in outcome

    def test_invoke_empty_path_returns_error(self):
        volume_tool = init_volume_tool()
        payload: dict[str, Any] = {"file_path": ""}
        outcome = volume_tool.invoke(payload)
        assert "Error" in outcome
|
|
||
|
|
||
class TestToolNameGeneration:
    """Checks for default tool-name derivation and tool-name validation."""

    def test_default_tool_name(self):
        volume_tool = init_volume_tool(volume_name="cat.schema.vol")
        assert volume_tool.name == "cat__schema__vol"

    @pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"])
    def test_valid_tool_names(self, tool_name):
        volume_tool = init_volume_tool(tool_name=tool_name)
        assert volume_tool.tool_name == tool_name
        if not tool_name:
            # With no custom name there is nothing further to compare here.
            return
        assert volume_tool.name == tool_name

    @pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"])
    def test_invalid_tool_names(self, tool_name):
        # Names containing '.' or '&' are expected to be rejected at construction.
        with pytest.raises(ValueError):
            init_volume_tool(tool_name=tool_name)
|
|
||
|
|
||
class TestArgsSchema:
    """Checks that the tool's argument schema exposes a documented ``file_path``."""

    def test_args_schema_has_file_path(self):
        schema_fields = init_volume_tool().args_schema.model_fields
        assert "file_path" in schema_fields
        assert schema_fields["file_path"].description is not None
|
|
||
|
|
||
class TestTracing:
    """Checks that running the tool records exactly one TOOL span in MLflow."""

    @staticmethod
    def _tool_spans_in_last_trace(span_name: str):
        """Return TOOL spans named ``span_name`` from the last active trace."""
        last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
        return last_trace.search_spans(name=span_name, span_type=SpanType.TOOL)

    def test_tracing_with_default_name(self):
        init_volume_tool()._run("reports/q4.txt")
        assert len(self._tool_spans_in_last_trace(VOLUME_NAME)) == 1

    def test_tracing_with_custom_name(self):
        init_volume_tool(tool_name="my_reader")._run("reports/q4.txt")
        assert len(self._tool_spans_in_last_trace("my_reader")) == 1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
integrations/openai/src/databricks_openai/uc_volume_tool.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| from typing import Any, Optional | ||
|
|
||
| from databricks_ai_bridge.uc_volume_tool import ( | ||
| UCVolumeToolInput, | ||
| UCVolumeToolMixin, | ||
| uc_volume_tool_trace, | ||
| ) | ||
| from openai import pydantic_function_tool | ||
| from openai.types.chat import ChatCompletionToolParam | ||
| from pydantic import Field, model_validator | ||
|
|
||
|
|
||
class UCVolumeTool(UCVolumeToolMixin):
    """
    OpenAI-compatible tool for reading files from a Databricks Unity Catalog Volume.

    Wraps ``UCVolumeToolMixin`` and exposes a ``tool`` spec for tool calling with the
    OpenAI SDK. Follows the same pattern as VectorSearchRetrieverTool.

    Example:
        Step 1: Call model with UCVolumeTool defined

        .. code-block:: python

            # `client` is an OpenAI SDK client created elsewhere.
            vol_tool = UCVolumeTool(volume_name="catalog.schema.my_documents")
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Read the Q4 summary from reports/q4_summary.txt"},
            ]
            first_response = client.chat.completions.create(
                model="gpt-4o", messages=messages, tools=[vol_tool.tool]
            )

        Step 2: Execute function code – parse the model's response and handle function calls.

        .. code-block:: python

            import json

            tool_call = first_response.choices[0].message.tool_calls[0]
            args = json.loads(tool_call.function.arguments)
            result = vol_tool.execute(file_path=args["file_path"])

        Step 3: Supply model with results – so it can incorporate them into its final response.

        .. code-block:: python

            messages.append(first_response.choices[0].message)
            messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})
            second_response = client.chat.completions.create(
                model="gpt-4o", messages=messages, tools=[vol_tool.tool]
            )
    """

    tool: Optional[ChatCompletionToolParam] = Field(
        None, description="The tool input used in the OpenAI chat completion SDK"
    )

    @model_validator(mode="after")
    def _validate_tool_inputs(self):
        """Ensure a workspace client exists and build the OpenAI ``tool`` spec."""
        if not self.workspace_client:
            # Imported lazily: only needed when the caller did not supply a client.
            from databricks.sdk import WorkspaceClient

            self.workspace_client = WorkspaceClient()

        resolved_name = self._get_tool_name()
        resolved_description = self.tool_description or self._get_default_tool_description()

        self.tool = pydantic_function_tool(
            UCVolumeToolInput,
            name=resolved_name,
            description=resolved_description,
        )

        # Drop the "strict" flag and the parameters' "additionalProperties" entry
        # from the generated spec whenever they are present.
        function_spec = self.tool.get("function", {})
        function_spec.pop("strict", None)
        function_spec.get("parameters", {}).pop("additionalProperties", None)

        return self

    @uc_volume_tool_trace
    def execute(self, file_path: str, **kwargs: Any) -> str:
        """
        Read a file from the volume by running the UCVolumeTool.

        Args:
            file_path: The path to the file relative to the volume root.

        Returns:
            The file contents as a string.
        """
        return self._read_file(file_path)
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.