Skip to content

Commit 5d668c1

Browse files
authored
Check function signature instead of output for auto-tracing retrievers (unitycatalog#861)
**PR Checklist** - [x] A description of the changes is added to the description of this PR. - [x] If there is a related issue, make sure it is linked to this PR. - [x] If you've fixed a bug or added code that should be tested, add tests! - [ ] If you've added or modified a feature, documentation in `docs` is updated -> documenting this in docs for Agent Tools https://github.com/databricks-eng/docs/pull/19358 **Description of changes** Following up from this [comment](unitycatalog#821 (comment)), we want to modify auto-tracing for retrievers to check the function signature instead of the output format so that we can capture errors that occur for retrievers as well. Our recommended way for querying vector search in the [documentation](https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#vector-search-retriever-tool-with-unity-catalog-functions) has data type `TABLE_TYPE` and we can match on this signature. Manual testing: - Langchain <img width="1307" alt="Screenshot 2025-01-27 at 4 20 28 PM" src="https://github.com/user-attachments/assets/2765a580-67c5-4de0-9526-03097ec35cca" /> <img width="1309" alt="Screenshot 2025-01-27 at 4 21 16 PM" src="https://github.com/user-attachments/assets/cef0bf0b-6af7-4a96-b721-f5cdf2611daf" /> - CrewAI <img width="1301" alt="Screenshot 2025-01-27 at 4 53 10 PM" src="https://github.com/user-attachments/assets/3b60b494-3dd0-40e4-9cac-5674b7790b55" /> <img width="1306" alt="Screenshot 2025-01-27 at 5 21 02 PM" src="https://github.com/user-attachments/assets/533d0e12-361e-4e01-99c0-095087b555cb" /> - LlamaIndex <img width="1303" alt="Screenshot 2025-01-27 at 5 10 42 PM" src="https://github.com/user-attachments/assets/81325a91-083d-4eef-badc-908eb892a53f" /> <img width="1312" alt="Screenshot 2025-01-27 at 5 11 50 PM" src="https://github.com/user-attachments/assets/499bad04-3ee2-4966-b211-84cd5f3c908e" /> --------- Signed-off-by: Ann Zhang <[email protected]>
1 parent a8fb85b commit 5d668c1

22 files changed

+779
-303
lines changed

.github/workflows/ucai-integration-tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ jobs:
194194
pip install core/.
195195
echo "PYTHONPATH=$(pwd)/core/src:\$PYTHONPATH" >> $GITHUB_ENV
196196
pip install integrations/gemini[dev]
197+
pip install mlflow
197198
- name: Run tests
198199
run: |
199200
pytest integrations/gemini/tests --ignore=integrations/gemini/tests/test_gemini_toolkit_oss.py

ai/core/src/unitycatalog/ai/core/base.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import json
22
import logging
33
import threading
4-
import time
54
from abc import ABC, abstractmethod
65
from copy import deepcopy
76
from dataclasses import dataclass
87
from typing import Any, Callable, Dict, Literal, Optional
98

109
from unitycatalog.ai.core.paged_list import PagedList
11-
from unitycatalog.ai.core.utils.function_processing_utils import auto_trace_retriever
10+
from unitycatalog.ai.core.utils.function_processing_utils import (
11+
_execute_uc_function_with_retriever_tracing,
12+
)
13+
from unitycatalog.ai.core.utils.validation_utils import has_retriever_signature
1214

1315
_logger = logging.getLogger(__name__)
1416

@@ -157,13 +159,14 @@ def execute_function(
157159
parameters = parameters or {}
158160
self.validate_input_params(function_info.input_params, parameters)
159161

160-
start_time_ns = time.time_ns()
161-
result = self._execute_uc_function(function_info, parameters, **kwargs)
162-
end_time_ns = time.time_ns()
162+
if kwargs.get("enable_retriever_tracing", False) and has_retriever_signature(
163+
function_info
164+
):
165+
return _execute_uc_function_with_retriever_tracing(
166+
self._execute_uc_function, function_info, parameters, **kwargs
167+
)
163168

164-
if kwargs.get("enable_retriever_tracing", False):
165-
auto_trace_retriever(function_name, parameters, result, start_time_ns, end_time_ns)
166-
return result
169+
return self._execute_uc_function(function_info, parameters, **kwargs)
167170

168171
@abstractmethod
169172
def _execute_uc_function(

ai/core/src/unitycatalog/ai/core/utils/function_processing_utils.py

+66-60
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
PydanticType,
2121
)
2222
from unitycatalog.ai.core.utils.type_utils import UC_TYPE_JSON_MAPPING
23-
from unitycatalog.ai.core.utils.validation_utils import (
24-
FullFunctionName,
25-
is_valid_retriever_output,
26-
)
23+
from unitycatalog.ai.core.utils.validation_utils import FullFunctionName
2724

2825
_logger = logging.getLogger(__name__)
2926

@@ -321,67 +318,76 @@ def supported_function_info_types():
321318
return types
322319

323320

324-
def auto_trace_retriever(
325-
function_name: str,
321+
def process_retriever_output(result: "FunctionExecutionResult") -> List[Dict[str, Any]]:
    """
    Process retriever output from result into mlflow.entities.Document format for tracing.

    CSV results are parsed with pandas and returned as one dictionary per row; a
    ``metadata`` column, when present, has its string cells evaluated as literals.
    Any other format is returned as-is when the value is not a string, and
    otherwise parsed as a Python literal with a JSON fallback.

    Args:
        result: The result of the function execution to be processed. Only its
            ``format`` and ``value`` attributes are read.

    Returns:
        Retriever output formatted into a list of Documents.

    Raises:
        ValueError, SyntaxError: If a string value can be parsed neither as a
            Python literal nor as JSON (the literal-parsing error is raised).
    """
    if result.format == "CSV":
        df = pd.read_csv(StringIO(result.value))
        if "metadata" in df.columns:
            # Blank cells are read as NaN (a float); feeding those to
            # ast.literal_eval raises, so only string cells are evaluated.
            df["metadata"] = df["metadata"].apply(
                lambda cell: ast.literal_eval(cell) if isinstance(cell, str) else cell
            )
        return df.to_dict(orient="records")

    value = result.value
    if not isinstance(value, str):
        return value

    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError) as literal_error:
        # JSON-only tokens (null / true / false) are not Python literals, so
        # give the JSON parser a chance before giving up.
        import json

        try:
            return json.loads(value)
        except ValueError:
            # Keep reporting the error callers saw before the fallback existed.
            raise literal_error
341+
342+
343+
def _execute_uc_function_with_retriever_tracing(
    _execute_uc_function: Callable,
    function_info: "FunctionInfo",
    parameters: Dict[str, Any],
    **kwargs: Any,
) -> "FunctionExecutionResult":
    """
    Executes a UC function with MLflow tracing with span type RETRIEVER enabled. If MLflow cannot
    be imported, the function executes without tracing and logs a warning.

    Tracing is best-effort: failures while recording the span (including a failure to convert
    the result into retriever documents) never change the returned result. Exceptions raised by
    the UC function execution itself are propagated to the caller after the span is closed.

    Args:
        _execute_uc_function (Callable): A function that executes the given UC function.
        function_info (FunctionInfo): Metadata about the UC function to be executed.
        parameters (Dict[str, Any]): Parameters to be passed to the function during execution.
        **kwargs (Any): Additional keyword arguments to be passed to the function.

    Returns:
        Any: The output of the function execution.

    Raises:
        Exception: Whatever ``_execute_uc_function`` raises.
    """
    # Only the imports are guarded, so an ImportError raised later (e.g. by the UC
    # function itself) is not mistaken for a missing MLflow installation.
    try:
        import mlflow
        from mlflow.entities import SpanType
    except ImportError as e:
        _logger.warning(
            f"Skipping tracing {function_info.full_name} as a retriever because of the following error:\n {e}"
        )
        return _execute_uc_function(function_info, parameters, **kwargs)

    result = None
    execution_error = None
    attempted = False

    @mlflow.trace(name=function_info.full_name, span_type=SpanType.RETRIEVER)
    def execute_retriever(parameters):
        nonlocal result, execution_error, attempted

        # Set inputs manually so we log {"query": "..."} instead of {"parameters": {"query": "..."}}
        if span := mlflow.get_current_active_span():
            span.set_inputs(parameters)

        attempted = True
        try:
            result = _execute_uc_function(function_info, parameters, **kwargs)
        except Exception as e:
            # Remember genuine execution failures so they can be told apart from
            # tracing-only failures once the span has been closed.
            execution_error = e
            raise

        # Re-raise errors so they can get traced
        if result.error:
            raise Exception(result.error)

        return process_retriever_output(result)

    try:
        execute_retriever(parameters)
    except Exception as e:
        if execution_error is None:
            # Either the error reported in the result was re-raised for the span, or
            # tracing / output conversion failed; neither may affect the caller.
            _logger.debug(
                f"Retriever tracing for {function_info.full_name} finished with an error:\n {e}"
            )

    if execution_error is not None:
        raise execution_error

    if not attempted:
        # Tracing failed before the UC function was invoked; run it untraced
        # rather than handing None back to the caller.
        return _execute_uc_function(function_info, parameters, **kwargs)

    return result

ai/core/src/unitycatalog/ai/core/utils/validation_utils.py

+13-20
Original file line numberDiff line numberDiff line change
@@ -142,39 +142,32 @@ def validate_function_name_length(function_name: str) -> None:
142142
)
143143

144144

145-
def is_valid_retriever_output(output: Any) -> bool:
145+
def has_retriever_signature(function_info: "FunctionInfo") -> bool:
    """
    Checks if the given function signature follows the retriever format for MLflow, which is a
    list of Documents.

    A function qualifies when the string form of its ``data_type`` contains ``TABLE_TYPE`` and
    one of its return parameters is named ``page_content`` with type name ``STRING``.

    Args:
        function_info: The function to determine if it has a valid retriever signature.

    Returns:
        bool: If the provided function has a valid retriever signature.
    """
    if "TABLE_TYPE" not in str(function_info.data_type):
        return False

    return_params = function_info.return_params

    if not (return_params and return_params.parameters):
        return False

    for param in return_params.parameters:
        param_dict = param.as_dict() if hasattr(param, "as_dict") else dict(param)

        # The type name may arrive as a plain string or as an Enum member depending
        # on the parameter model; compare on the underlying value either way.
        type_name = param_dict.get("type_name")
        type_name = getattr(type_name, "value", type_name)

        if param_dict.get("name") == "page_content" and type_name == "STRING":
            return True

    return False
178171

179172

180173
def mlflow_tracing_enabled(integration_name: str) -> bool:

ai/core/src/unitycatalog/ai/test_utils/function_utils.py

+54
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,68 @@
33
from contextlib import contextmanager
44
from typing import Any, Callable, Generator, NamedTuple, Optional
55

6+
from databricks.sdk.service.catalog import (
7+
ColumnTypeName,
8+
FunctionParameterInfo,
9+
FunctionParameterInfos,
10+
)
11+
612
from unitycatalog.ai.core.databricks import DatabricksFunctionClient
713
from unitycatalog.ai.core.utils.function_processing_utils import get_tool_name
14+
from unitycatalog.client import (
15+
ColumnTypeName as OSSColumnTypeName,
16+
)
17+
from unitycatalog.client import (
18+
FunctionParameterInfo as OSSFunctionParameterInfo,
19+
)
20+
from unitycatalog.client import (
21+
FunctionParameterInfos as OSSFunctionParameterInfos,
22+
)
823

924
CATALOG = "integration_testing"
1025

1126
RETRIEVER_OUTPUT_SCALAR = '[{"page_content": "# Technology partners\\n## What is Databricks Partner Connect?\\n", "metadata": {"similarity_score": 0.010178182, "chunk_id": "0217a07ba2fec61865ce408043acf1cf"}}, {"page_content": "# Technology partners\\n## What is Databricks?\\n", "metadata": {"similarity_score": 0.010178183, "chunk_id": "0217a07ba2fec61865ce408043acf1cd"}}]'
1227
RETRIEVER_OUTPUT_CSV = "page_content,metadata\n\"# Technology partners\n## What is Databricks Partner Connect?\n\",\"{'similarity_score': 0.010178182, 'chunk_id': '0217a07ba2fec61865ce408043acf1cf'}\"\n\"# Technology partners\n## What is Databricks?\n\",\"{'similarity_score': 0.010178183, 'chunk_id': '0217a07ba2fec61865ce408043acf1cd'}\"\n"
1328

29+
RETRIEVER_TABLE_FULL_DATA_TYPE = "(page_content STRING, metadata MAP<STRING, STRING>)"
30+
RETRIEVER_TABLE_RETURN_PARAMS = FunctionParameterInfos(
31+
parameters=[
32+
FunctionParameterInfo(
33+
name="page_content",
34+
type_text="string",
35+
type_name=ColumnTypeName.STRING,
36+
type_json='{"name":"page_content","type":"string","nullable":true,"metadata":{}}',
37+
position=0,
38+
),
39+
FunctionParameterInfo(
40+
name="metadata",
41+
type_text="map<string,string>",
42+
type_name=ColumnTypeName.MAP,
43+
type_json='{"name":"metadata","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}}',
44+
position=1,
45+
),
46+
]
47+
)
48+
RETRIEVER_TABLE_RETURN_PARAMS_OSS = OSSFunctionParameterInfos(
49+
parameters=[
50+
OSSFunctionParameterInfo(
51+
name="page_content",
52+
type_text="string",
53+
type_name=OSSColumnTypeName.STRING,
54+
type_json='{"name":"page_content","type":"string","nullable":true,"metadata":{}}',
55+
position=0,
56+
),
57+
OSSFunctionParameterInfo(
58+
name="metadata",
59+
type_text="map<string,string>",
60+
type_name=OSSColumnTypeName.MAP,
61+
type_json='{"name":"metadata","type":{"type":"map","keyType":"string","valueType":"string","valueContainsNull":true},"nullable":true,"metadata":{}}',
62+
position=1,
63+
),
64+
]
65+
)
66+
67+
1468
_logger = logging.getLogger(__name__)
1569

1670

0 commit comments

Comments
 (0)