diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index bfa52f8cb..8ab81b596 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -27,6 +27,7 @@ MCPServer, ) from databricks_langchain.store import AsyncDatabricksStore, DatabricksStore +from databricks_langchain.uc_volume_tool import UCVolumeTool from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool from databricks_langchain.vectorstores import DatabricksVectorSearch @@ -40,6 +41,7 @@ "DatabricksStore", "DatabricksVectorSearch", "GenieAgent", + "UCVolumeTool", "VectorSearchRetrieverTool", "UCFunctionToolkit", "UnityCatalogTool", diff --git a/integrations/langchain/src/databricks_langchain/uc_volume_tool.py b/integrations/langchain/src/databricks_langchain/uc_volume_tool.py new file mode 100644 index 000000000..cb071b690 --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/uc_volume_tool.py @@ -0,0 +1,57 @@ +from typing import Type + +from databricks_ai_bridge.uc_volume_tool import ( + UCVolumeToolInput, + UCVolumeToolMixin, + uc_volume_tool_trace, +) +from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field, model_validator + + +class UCVolumeTool(BaseTool, UCVolumeToolMixin): + """ + A LangChain tool for reading files from a Databricks Unity Catalog Volume. + + This class integrates with Databricks UC Volumes and provides a convenient interface + for building a file reading tool for agents. + + Example: + + .. code-block:: python + + from databricks_langchain import UCVolumeTool, ChatDatabricks + + vol_tool = UCVolumeTool( + volume_name="catalog.schema.my_documents", + tool_name="document_reader", + tool_description="Reads files from the company documents volume.", + ) + + # Test locally + vol_tool.invoke("reports/q4_summary.txt") + + # Bind to LLM + llm = ChatDatabricks(endpoint="databricks-claude-sonnet-4-5") + llm_with_tools = llm.bind_tools([vol_tool]) + llm_with_tools.invoke("Read the Q4 summary from reports/q4_summary.txt") + """ + + # BaseTool requires 'name' and 'description' fields; populated in _validate_tool_inputs() + name: str = Field(default="", description="The name of the tool") + description: str = Field(default="", description="The description of the tool") + args_schema: Type[BaseModel] = UCVolumeToolInput + + @model_validator(mode="after") + def _validate_tool_inputs(self): + self.name = self._get_tool_name() + self.description = self.tool_description or self._get_default_tool_description() + if not self.workspace_client: + from databricks.sdk import WorkspaceClient + + self.workspace_client = WorkspaceClient() + return self + + @uc_volume_tool_trace + def _run(self, file_path: str, **kwargs) -> str: + return self._read_file(file_path) diff --git a/integrations/langchain/tests/unit_tests/test_uc_volume_tool.py b/integrations/langchain/tests/unit_tests/test_uc_volume_tool.py new file mode 100644 index 000000000..b3b83787b --- /dev/null +++ b/integrations/langchain/tests/unit_tests/test_uc_volume_tool.py @@ -0,0 +1,120 @@ +import io +from typing import Any +from unittest.mock import MagicMock + +import mlflow +import pytest +from databricks_ai_bridge.test_utils.uc_volume import ( # noqa: F401 + SAMPLE_FILE_CONTENT, + VOLUME_NAME, + mock_workspace_client, +) +from langchain_core.tools import BaseTool +from mlflow.entities import SpanType + +from databricks_langchain import UCVolumeTool + + +def init_volume_tool( + volume_name: str = VOLUME_NAME, + tool_name: str | None = None, + tool_description: str | None = None, + **kwargs: Any, +) -> UCVolumeTool: + kwargs.update( + { + "volume_name": volume_name, + "tool_name": tool_name, + "tool_description": tool_description, + } + ) + return UCVolumeTool(**kwargs) + + +class TestInit: + def test_init_is_base_tool(self): + tool = init_volume_tool() + assert isinstance(tool, BaseTool) + + def test_init_with_custom_name_and_description(self): + tool = init_volume_tool(tool_name="my_reader", tool_description="Reads docs") + assert tool.name == "my_reader" + assert tool.description == "Reads docs" + + def test_init_default_name_from_volume_name(self): + tool = init_volume_tool() + assert tool.name == VOLUME_NAME.replace(".", "__") + + def test_init_default_description(self): + tool = init_volume_tool() + assert VOLUME_NAME in tool.description + assert "Reads files" in tool.description + + +class TestInvoke: + def test_invoke_returns_file_content(self): + tool = init_volume_tool() + result = tool.invoke("reports/q4.txt") + assert result == SAMPLE_FILE_CONTENT + + def test_invoke_with_dict_input(self): + tool = init_volume_tool() + input_dict: dict[str, Any] = {"file_path": "reports/q4.txt"} + result = tool.invoke(input_dict) + assert result == SAMPLE_FILE_CONTENT + + def test_invoke_binary_file_returns_error(self, mock_workspace_client): + mock_resp = MagicMock() + mock_resp.contents = io.BytesIO(b"\x80\x81\x82\x83") + mock_workspace_client.files.download.return_value = mock_resp + + tool = init_volume_tool() + result = tool.invoke("image.png") + assert "binary file" in result + + def test_invoke_empty_path_returns_error(self): + tool = init_volume_tool() + input_dict: dict[str, Any] = {"file_path": ""} + result = tool.invoke(input_dict) + assert "Error" in result + + +class TestToolNameGeneration: + def test_default_tool_name(self): + tool = init_volume_tool(volume_name="cat.schema.vol") + assert tool.name == "cat__schema__vol" + + @pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"]) + def test_valid_tool_names(self, tool_name): + tool = init_volume_tool(tool_name=tool_name) + assert tool.tool_name == tool_name + if tool_name: + assert tool.name == tool_name + + @pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"]) + def test_invalid_tool_names(self, tool_name): + with pytest.raises(ValueError): + init_volume_tool(tool_name=tool_name) + + +class TestArgsSchema: + def test_args_schema_has_file_path(self): + tool = init_volume_tool() + assert "file_path" in tool.args_schema.model_fields + assert tool.args_schema.model_fields["file_path"].description is not None + + +class TestTracing: + def test_tracing_with_default_name(self): + tool = init_volume_tool() + tool._run("reports/q4.txt") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + spans = trace.search_spans(name=VOLUME_NAME, span_type=SpanType.TOOL) + assert len(spans) == 1 + + def test_tracing_with_custom_name(self): + tool = init_volume_tool(tool_name="my_reader") + tool._run("reports/q4.txt") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + spans = trace.search_spans(name="my_reader", span_type=SpanType.TOOL) + assert len(spans) == 1 diff --git a/integrations/openai/src/databricks_openai/__init__.py b/integrations/openai/src/databricks_openai/__init__.py index ce04d0c01..e8ef0428f 100644 --- a/integrations/openai/src/databricks_openai/__init__.py +++ b/integrations/openai/src/databricks_openai/__init__.py @@ -17,11 +17,13 @@ from unitycatalog.ai.openai.toolkit import UCFunctionToolkit from databricks_openai.mcp_server_toolkit import McpServerToolkit, ToolInfo +from databricks_openai.uc_volume_tool import UCVolumeTool from databricks_openai.utils.clients import AsyncDatabricksOpenAI, DatabricksOpenAI from databricks_openai.vector_search_retriever_tool import VectorSearchRetrieverTool # Expose all integrations to users under databricks-openai __all__ = [ + "UCVolumeTool", "VectorSearchRetrieverTool", "UCFunctionToolkit", "DatabricksFunctionClient", diff --git a/integrations/openai/src/databricks_openai/uc_volume_tool.py b/integrations/openai/src/databricks_openai/uc_volume_tool.py new file mode 100644 index 000000000..1524d0d2c --- /dev/null +++ b/integrations/openai/src/databricks_openai/uc_volume_tool.py @@ -0,0 +1,94 @@ +from typing import Any, Optional + +from databricks_ai_bridge.uc_volume_tool import ( + UCVolumeToolInput, + UCVolumeToolMixin, + uc_volume_tool_trace, +) +from openai import pydantic_function_tool +from openai.types.chat import ChatCompletionToolParam +from pydantic import Field, model_validator + + +class UCVolumeTool(UCVolumeToolMixin): + """ + An OpenAI-compatible tool for reading files from a Databricks Unity Catalog Volume. + + This class integrates with Databricks UC Volumes and provides a convenient interface + for tool calling using the OpenAI SDK. Follows the same pattern as + VectorSearchRetrieverTool. + + Example: + Step 1: Call model with UCVolumeTool defined + + .. code-block:: python + + vol_tool = UCVolumeTool(volume_name="catalog.schema.my_documents") + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Read the Q4 summary from reports/q4_summary.txt"}, + ] + first_response = client.chat.completions.create( + model="gpt-4o", messages=messages, tools=[vol_tool.tool] + ) + + Step 2: Execute function code – parse the model's response and handle function calls. + + .. code-block:: python + + tool_call = first_response.choices[0].message.tool_calls[0] + args = json.loads(tool_call.function.arguments) + result = vol_tool.execute(file_path=args["file_path"]) + + Step 3: Supply model with results – so it can incorporate them into its final response. + + .. code-block:: python + + messages.append(first_response.choices[0].message) + messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result}) + second_response = client.chat.completions.create( + model="gpt-4o", messages=messages, tools=[vol_tool.tool] + ) + """ + + tool: Optional[ChatCompletionToolParam] = Field( + None, description="The tool input used in the OpenAI chat completion SDK" + ) + + @model_validator(mode="after") + def _validate_tool_inputs(self): + if not self.workspace_client: + from databricks.sdk import WorkspaceClient + + self.workspace_client = WorkspaceClient() + + tool_name = self._get_tool_name() + + self.tool = pydantic_function_tool( + UCVolumeToolInput, + name=tool_name, + description=self.tool_description or self._get_default_tool_description(), + ) + if "function" in self.tool and "strict" in self.tool["function"]: + del self.tool["function"]["strict"] + if ( + "function" in self.tool + and "parameters" in self.tool["function"] + and "additionalProperties" in self.tool["function"]["parameters"] + ): + del self.tool["function"]["parameters"]["additionalProperties"] + + return self + + @uc_volume_tool_trace + def execute(self, file_path: str, **kwargs: Any) -> str: + """ + Execute the UCVolumeTool to read a file from the volume. + + Args: + file_path: The path to the file relative to the volume root. + + Returns: + The file contents as a string. + """ + return self._read_file(file_path) diff --git a/integrations/openai/tests/unit_tests/test_uc_volume_tool.py b/integrations/openai/tests/unit_tests/test_uc_volume_tool.py new file mode 100644 index 000000000..7f44b7284 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_uc_volume_tool.py @@ -0,0 +1,161 @@ +import io +import json +from typing import Any, Dict, Optional, cast +from unittest.mock import MagicMock + +import mlflow +import pytest +from databricks_ai_bridge.test_utils.uc_volume import ( # noqa: F401 + SAMPLE_FILE_CONTENT, + VOLUME_NAME, + mock_workspace_client, +) +from mlflow.entities import SpanType +from pydantic import BaseModel + +from databricks_openai import UCVolumeTool + + +def init_volume_tool( + volume_name: str = VOLUME_NAME, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + **kwargs: Any, +) -> UCVolumeTool: + kwargs.update( + { + "volume_name": volume_name, + "tool_name": tool_name, + "tool_description": tool_description, + } + ) + return UCVolumeTool(**kwargs) + + +class TestInit: + def test_init_is_base_model(self): + tool = init_volume_tool() + assert isinstance(tool, BaseModel) + + def test_init_creates_tool_spec(self): + tool = init_volume_tool() + assert tool.tool is not None + tool_spec = tool.tool + assert tool_spec["type"] == "function" + assert "function" in tool_spec + + def test_init_with_custom_name_and_description(self): + tool = init_volume_tool(tool_name="my_reader", tool_description="Reads docs") + assert tool.tool is not None + tool_spec = tool.tool + assert tool_spec["function"]["name"] == "my_reader" + assert tool_spec["function"]["description"] == "Reads docs" + + def test_init_default_name_from_volume_name(self): + tool = init_volume_tool() + assert tool.tool is not None + assert tool.tool["function"]["name"] == VOLUME_NAME.replace(".", "__") + + def test_init_no_strict_mode(self): + tool = init_volume_tool() + assert tool.tool is not None + assert "strict" not in tool.tool.get("function", {}) + + def test_init_no_additional_properties(self): + tool = init_volume_tool() + assert tool.tool is not None + assert "additionalProperties" not in tool.tool["function"]["parameters"] + + +class TestToolSchema: + def test_tool_schema_has_file_path(self): + tool = init_volume_tool() + schema = cast(Dict[str, Any], tool.tool) + properties = schema["function"]["parameters"]["properties"] + assert "file_path" in properties + assert "description" in properties["file_path"] + + def test_tool_schema_file_path_is_required(self): + tool = init_volume_tool() + schema = cast(Dict[str, Any], tool.tool) + required = schema["function"]["parameters"].get("required", []) + assert "file_path" in required + + +class TestExecute: + def test_execute_returns_file_content(self): + tool = init_volume_tool() + result = tool.execute(file_path="reports/q4.txt") + assert result == SAMPLE_FILE_CONTENT + + def test_execute_binary_file_returns_error(self, mock_workspace_client): + mock_resp = MagicMock() + mock_resp.contents = io.BytesIO(b"\x80\x81\x82\x83") + mock_workspace_client.files.download.return_value = mock_resp + + tool = init_volume_tool() + result = tool.execute(file_path="image.png") + assert "binary file" in result + + def test_execute_empty_path_returns_error(self): + tool = init_volume_tool() + result = tool.execute(file_path="") + assert "Error" in result + + def test_execute_calls_correct_volume_path(self, mock_workspace_client): + tool = init_volume_tool(volume_name="cat.schema.vol") + tool.execute(file_path="subfolder/file.txt") + mock_workspace_client.files.download.assert_called_once_with( + "/Volumes/cat/schema/vol/subfolder/file.txt" + ) + + +class TestToolNameGeneration: + def test_default_tool_name(self): + tool = init_volume_tool() + assert tool.tool is not None + assert tool.tool["function"]["name"] == VOLUME_NAME.replace(".", "__") + + @pytest.mark.parametrize( + "volume_name", + [ + "catalog.schema.really_really_really_long_volume_name_that_should_be_truncated_to_64_chars" + ], + ) + def test_long_volume_name_truncated(self, volume_name): + tool = init_volume_tool(volume_name=volume_name) + assert tool.tool is not None + assert len(tool.tool["function"]["name"]) <= 64 + + @pytest.mark.parametrize("tool_name", [None, "valid_tool_name", "test_tool"]) + def test_valid_tool_names(self, tool_name): + tool = init_volume_tool(tool_name=tool_name) + assert tool.tool_name == tool_name + + @pytest.mark.parametrize("tool_name", ["test.tool.name", "tool&name"]) + def test_invalid_tool_names(self, tool_name): + with pytest.raises(ValueError): + init_volume_tool(tool_name=tool_name) + + +class TestTracing: + def test_tracing_with_default_name(self): + tool = init_volume_tool() + tool.execute(file_path="reports/q4.txt") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + spans = trace.search_spans(name=VOLUME_NAME, span_type=SpanType.TOOL) + assert len(spans) == 1 + + def test_tracing_with_custom_name(self): + tool = init_volume_tool(tool_name="my_reader") + tool.execute(file_path="reports/q4.txt") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + spans = trace.search_spans(name="my_reader", span_type=SpanType.TOOL) + assert len(spans) == 1 + + def test_tracing_captures_inputs(self): + tool = init_volume_tool() + tool.execute(file_path="reports/q4.txt") + trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) + inputs = json.loads(trace.to_dict()["data"]["spans"][0]["attributes"]["mlflow.spanInputs"]) + assert inputs["file_path"] == "reports/q4.txt" diff --git a/src/databricks_ai_bridge/test_utils/uc_volume.py b/src/databricks_ai_bridge/test_utils/uc_volume.py new file mode 100644 index 000000000..cd888f830 --- /dev/null +++ b/src/databricks_ai_bridge/test_utils/uc_volume.py @@ -0,0 +1,32 @@ +import io +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + +VOLUME_NAME = "test_catalog.test_schema.test_volume" + +SAMPLE_FILE_CONTENT = "This is the content of the test file." +SAMPLE_FILE_CONTENT_BYTES = SAMPLE_FILE_CONTENT.encode("utf-8") + + +def _make_download_response(content_bytes: bytes = SAMPLE_FILE_CONTENT_BYTES) -> MagicMock: + """Create a mock response for workspace_client.files.download().""" + mock_resp = MagicMock() + mock_resp.contents = io.BytesIO(content_bytes) + return mock_resp + + +@pytest.fixture(autouse=True) +def mock_workspace_client() -> Generator: + """Mock WorkspaceClient for UC Volume operations.""" + mock_client = MagicMock() + + # Mock files.download + mock_client.files.download.return_value = _make_download_response() + + with patch( + "databricks.sdk.WorkspaceClient", + return_value=mock_client, + ): + yield mock_client diff --git a/src/databricks_ai_bridge/uc_volume_tool.py b/src/databricks_ai_bridge/uc_volume_tool.py new file mode 100644 index 000000000..4e3d04c0f --- /dev/null +++ b/src/databricks_ai_bridge/uc_volume_tool.py @@ -0,0 +1,113 @@ +import logging +import re +from functools import wraps +from typing import Optional + +import mlflow +from databricks.sdk import WorkspaceClient +from mlflow.entities import SpanType +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from databricks_ai_bridge.utils.uc_volume import read_volume_file + +_logger = logging.getLogger(__name__) +DEFAULT_TOOL_DESCRIPTION = "A tool for reading files from a Databricks Unity Catalog Volume." + + +def uc_volume_tool_trace(func): + """ + Decorator factory to trace UCVolumeTool with the tool name. + Parallels vector_search_retriever_tool_trace. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + traced_func = mlflow.trace( + name=self.tool_name or self.volume_name, span_type=SpanType.TOOL + )(func) + return traced_func(self, *args, **kwargs) + + return wrapper + + +class UCVolumeToolInput(BaseModel): + """Input schema that the LLM sees and generates.""" + + model_config = ConfigDict(extra="allow") + file_path: str = Field( + description=( + "The path to the file to read, relative to the volume root. " + "For example: 'reports/q4_summary.txt' or 'data/config.json'." + ) + ) + + +class UCVolumeToolMixin(BaseModel): + """ + Mixin class for Databricks UC Volume tools. + This class provides the common structure and interface that framework-specific + implementations (LangChain, OpenAI) should follow. + Parallels VectorSearchRetrieverToolMixin. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + volume_name: str = Field(..., description="The full volume name: 'catalog.schema.volume'.") + tool_name: Optional[str] = Field(None, description="The name of the tool.") + tool_description: Optional[str] = Field(None, description="A description of the tool.") + workspace_client: Optional[WorkspaceClient] = Field( + None, + description="Optional pre-configured WorkspaceClient for authentication.", + ) + + @model_validator(mode="after") + def _validate_volume_name(self): + parts = self.volume_name.split(".") + if len(parts) != 3: + raise ValueError( + f"volume_name must be 'catalog.schema.volume', got '{self.volume_name}'" + ) + return self + + @field_validator("tool_name") + def validate_tool_name(cls, tool_name): + if tool_name is not None: + pattern = re.compile(r"^[a-zA-Z0-9_-]{1,64}$") + if not pattern.fullmatch(tool_name): + raise ValueError("tool_name must match the pattern '^[a-zA-Z0-9_-]{1,64}$'") + return tool_name + + def _get_tool_name(self) -> str: + tool_name = self.tool_name or self.volume_name.replace(".", "__") + if len(tool_name) > 64: + _logger.warning( + f"Tool name {tool_name} is too long, truncating to 64 characters {tool_name[-64:]}." + ) + return tool_name[-64:] + return tool_name + + def _get_default_tool_description(self) -> str: + return ( + f"{DEFAULT_TOOL_DESCRIPTION} " + f"Reads files from the Unity Catalog volume '{self.volume_name}'. " + f"Provide the file path relative to the volume root." + ) + + def _read_file(self, file_path: str) -> str: + """ + Core execution logic shared across frameworks. + Reads a file from the volume and returns its text content. + """ + from databricks.sdk import WorkspaceClient + + wc = self.workspace_client or WorkspaceClient() + if not file_path: + return "Error: file_path is required." + try: + result = read_volume_file(self.volume_name, file_path, workspace_client=wc) + except UnicodeDecodeError: + return ( + f"Cannot read '{file_path}': binary file. This tool supports text-based files only." + ) + # read_volume_file returns str by default (as_bytes=False) + assert isinstance(result, str) + return result diff --git a/src/databricks_ai_bridge/utils/uc_volume.py b/src/databricks_ai_bridge/utils/uc_volume.py new file mode 100644 index 000000000..ff185e01b --- /dev/null +++ b/src/databricks_ai_bridge/utils/uc_volume.py @@ -0,0 +1,55 @@ +import logging +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from databricks.sdk import WorkspaceClient + +_logger = logging.getLogger(__name__) + + +def _to_volume_path(volume_name: str, file_path: str = "") -> str: + """ + Convert 'catalog.schema.volume' + 'path/to/file' + to '/Volumes/catalog/schema/volume/path/to/file'. + """ + parts = volume_name.split(".") + if len(parts) != 3: + raise ValueError(f"volume_name must be 'catalog.schema.volume', got '{volume_name}'") + base = f"/Volumes/{parts[0]}/{parts[1]}/{parts[2]}" + if file_path: + return f"{base}/{file_path.lstrip('/')}" + return base + + +def read_volume_file( + volume_name: str, + file_path: str, + *, + encoding: str = "utf-8", + as_bytes: bool = False, + workspace_client: Optional["WorkspaceClient"] = None, +) -> Union[str, bytes]: + """ + Read a file from a Unity Catalog Volume. + + Args: + volume_name: Full volume name as 'catalog.schema.volume'. + file_path: Relative path to the file within the volume. + encoding: Text encoding (default: utf-8). Ignored if as_bytes=True. + as_bytes: If True, return raw bytes instead of decoded string. + workspace_client: Optional pre-configured WorkspaceClient. + + Returns: + File contents as a string (default) or bytes (if as_bytes=True). + """ + from databricks.sdk import WorkspaceClient + + client = workspace_client or WorkspaceClient() + full_path = _to_volume_path(volume_name, file_path) + resp = client.files.download(full_path) + if resp.contents is None: + raise ValueError(f"No content returned for '{full_path}'") + content = resp.contents.read() + if as_bytes: + return content + return content.decode(encoding) diff --git a/tests/databricks_ai_bridge/test_uc_volume_tool.py b/tests/databricks_ai_bridge/test_uc_volume_tool.py new file mode 100644 index 000000000..78a0c0c4d --- /dev/null +++ b/tests/databricks_ai_bridge/test_uc_volume_tool.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +from databricks_ai_bridge.test_utils.uc_volume import mock_workspace_client # noqa: F401 +from databricks_ai_bridge.uc_volume_tool import UCVolumeToolMixin + +VOLUME_NAME = "test_catalog.test_schema.test_volume" + + +class DummyUCVolumeTool(UCVolumeToolMixin): + pass + + +class TestVolumeNameValidation: + def test_valid_volume_name(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + assert tool.volume_name == VOLUME_NAME + + def test_invalid_volume_name_two_parts(self): + with pytest.raises(ValidationError): + DummyUCVolumeTool(volume_name="catalog.schema") + + def test_invalid_volume_name_one_part(self): + with pytest.raises(ValidationError): + DummyUCVolumeTool(volume_name="just_a_name") + + def test_invalid_volume_name_four_parts(self): + with pytest.raises(ValidationError): + DummyUCVolumeTool(volume_name="a.b.c.d") + + +class TestToolNameValidation: + def test_valid_tool_name(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name="my_tool") + assert tool.tool_name == "my_tool" + + def test_valid_tool_name_with_hyphens(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name="my-tool-123") + assert tool.tool_name == "my-tool-123" + + def test_invalid_tool_name_special_chars(self): + with pytest.raises(ValidationError): + DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name="invalid@@@") + + def test_invalid_tool_name_dots(self): + with pytest.raises(ValidationError): + DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name="tool.name") + + def test_none_tool_name_is_valid(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name=None) + assert tool.tool_name is None + + +class TestGetToolName: + def test_default_tool_name_from_volume_name(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + assert tool._get_tool_name() == "test_catalog__test_schema__test_volume" + + def test_custom_tool_name(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME, tool_name="my_reader") + assert tool._get_tool_name() == "my_reader" + + def test_long_tool_name_truncated(self): + tool = DummyUCVolumeTool( + volume_name="catalog.schema.really_really_really_long_volume_name_that_should_be_truncated_to_64_chars" + ) + name = tool._get_tool_name() + assert len(name) <= 64 + + @pytest.mark.parametrize( + "volume_name,expected", + [ + ("cat.schema.vol", "cat__schema__vol"), + ("my_cat.my_schema.my_vol", "my_cat__my_schema__my_vol"), + ], + ) + def test_name_derivation(self, volume_name, expected): + tool = DummyUCVolumeTool(volume_name=volume_name) + assert tool._get_tool_name() == expected + + +class TestGetDefaultToolDescription: + def test_includes_volume_name(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + desc = tool._get_default_tool_description() + assert VOLUME_NAME in desc + assert "Reads files" in desc + + +class TestReadFile: + def test_read_file_returns_content(self, mock_workspace_client): + # The autouse fixture patches WorkspaceClient() globally + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + result = tool._read_file("reports/q4.txt") + assert result == "This is the content of the test file." + + def test_read_file_empty_path_returns_error(self): + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + result = tool._read_file("") + assert "Error" in result + + def test_read_file_binary_returns_error(self, mock_workspace_client): + import io + + # Make download return non-utf8 bytes + mock_resp = MagicMock() + mock_resp.contents = io.BytesIO(b"\x80\x81\x82\x83") + mock_workspace_client.files.download.return_value = mock_resp + + tool = DummyUCVolumeTool(volume_name=VOLUME_NAME) + result = tool._read_file("image.png") + assert "binary file" in result + assert "text-based files only" in result + + def test_read_file_calls_correct_path(self, mock_workspace_client): + tool = DummyUCVolumeTool(volume_name="cat.schema.vol") + tool._read_file("subfolder/file.txt") + mock_workspace_client.files.download.assert_called_once_with( + "/Volumes/cat/schema/vol/subfolder/file.txt" + )