|
20 | 20 | PydanticType,
|
21 | 21 | )
|
22 | 22 | from unitycatalog.ai.core.utils.type_utils import UC_TYPE_JSON_MAPPING
|
23 |
| -from unitycatalog.ai.core.utils.validation_utils import ( |
24 |
| - FullFunctionName, |
25 |
| - is_valid_retriever_output, |
26 |
| -) |
| 23 | +from unitycatalog.ai.core.utils.validation_utils import FullFunctionName |
27 | 24 |
|
28 | 25 | _logger = logging.getLogger(__name__)
|
29 | 26 |
|
@@ -321,67 +318,76 @@ def supported_function_info_types():
|
321 | 318 | return types
|
322 | 319 |
|
323 | 320 |
|
324 |
| -def auto_trace_retriever( |
325 |
| - function_name: str, |
| 321 | +def process_retriever_output(result: "FunctionExecutionResult") -> List[Dict[str, Any]]: |
| 322 | + """ |
| 323 | + Process retriever output from result into mlflow.entities.Document format for tracing. |
| 324 | +
|
| 325 | + Args: |
| 326 | + result: The result of the function execution to be processed. |
| 327 | +
|
| 328 | + Returns: |
| 329 | + Retriever output formatted into a list of Documents. |
| 330 | + """ |
| 331 | + if result.format == "CSV": |
| 332 | + df = pd.read_csv(StringIO(result.value)) |
| 333 | + if "metadata" in df.columns: |
| 334 | + df["metadata"] = df["metadata"].apply(ast.literal_eval) |
| 335 | + output = df.to_dict(orient="records") |
| 336 | + else: |
| 337 | + value = result.value |
| 338 | + output = ast.literal_eval(value) if isinstance(value, str) else value |
| 339 | + |
| 340 | + return output |
| 341 | + |
| 342 | + |
| 343 | +def _execute_uc_function_with_retriever_tracing( |
| 344 | + _execute_uc_function: Callable, |
| 345 | + function_info: "FunctionInfo", |
326 | 346 | parameters: Dict[str, Any],
|
327 |
| - result: "FunctionExecutionResult", |
328 |
| - start_time_ns: int, |
329 |
| - end_time_ns: int, |
330 |
| -): |
| 347 | + **kwargs: Any, |
| 348 | +) -> "FunctionExecutionResult": |
331 | 349 | """
|
332 |
| - If the given function is a retriever, trace the function given the provided start and end time. |
333 |
| - A function is considered a retriever if the result is of valid retriever output format. |
| 350 | + Executes a UC function with MLflow tracing with span type RETRIEVER enabled. If MLflow cannot |
| 351 | + be imported, the function executes without tracing and logs a warning. |
334 | 352 |
|
335 | 353 | Args:
|
336 |
| - function_name: The function name. |
337 |
| - parameters: The input parameters to the function. |
338 |
| - result: The output result of the function. |
339 |
| - start_time_ns: The start time of the function in nanoseconds. |
340 |
| - end_time_ns: The end time of the function in nanoseconds. |
| 354 | + _execute_uc_function (Callable): A function that executes the given UC function. |
| 355 | + function_info (FunctionInfo): Metadata about the UC function to be executed. |
| 356 | + parameters (Dict[str, Any]): Parameters to be passed to the function during execution. |
| 357 | + **kwargs (Any): Additional keyword arguments to be passed to the function. |
| 358 | +
|
| 359 | + Returns: |
| 360 | + Any: The output of the function execution. |
341 | 361 | """
|
342 | 362 | try:
|
343 |
| - if result.format == "CSV": |
344 |
| - df = pd.read_csv(StringIO(result.value)) |
345 |
| - if "metadata" in df.columns: |
346 |
| - df["metadata"] = df["metadata"].apply(ast.literal_eval) |
347 |
| - output = df.to_dict(orient="records") |
348 |
| - else: |
349 |
| - value = result.value |
350 |
| - output = ast.literal_eval(value) if isinstance(value, str) else value |
351 |
| - |
352 |
| - if is_valid_retriever_output(output): |
353 |
| - import mlflow |
354 |
| - from mlflow import MlflowClient |
355 |
| - from mlflow.entities import SpanType |
356 |
| - |
357 |
| - client = MlflowClient() |
358 |
| - common_params = dict( |
359 |
| - name=function_name, |
360 |
| - span_type=SpanType.RETRIEVER, |
361 |
| - inputs=parameters, |
362 |
| - start_time_ns=start_time_ns, |
363 |
| - ) |
| 363 | + import mlflow |
| 364 | + from mlflow.entities import SpanType |
364 | 365 |
|
365 |
| - if parent_span := mlflow.get_current_active_span(): |
366 |
| - span = client.start_span( |
367 |
| - request_id=parent_span.request_id, |
368 |
| - parent_id=parent_span.span_id, |
369 |
| - **common_params, |
370 |
| - ) |
371 |
| - client.end_span( |
372 |
| - request_id=span.request_id, |
373 |
| - span_id=span.span_id, |
374 |
| - outputs=output, |
375 |
| - end_time_ns=end_time_ns, |
376 |
| - ) |
377 |
| - else: |
378 |
| - span = client.start_trace(**common_params) |
379 |
| - client.end_trace( |
380 |
| - request_id=span.request_id, outputs=output, end_time_ns=end_time_ns |
381 |
| - ) |
382 |
| - except Exception as e: |
383 |
| - # Ignoring exceptions because auto-tracing retriever is not essential functionality |
384 |
| - _logger.debug( |
385 |
| - f"Skipping tracing {function_name} as a retriever because of the following error:\n {e}" |
| 366 | + result = None |
| 367 | + |
| 368 | + @mlflow.trace(name=function_info.full_name, span_type=SpanType.RETRIEVER) |
| 369 | + def execute_retriever(parameters): |
| 370 | + # Set inputs manually so we log {"query": "..."} instead of {"parameters": {"query": "..."}} |
| 371 | + if span := mlflow.get_current_active_span(): |
| 372 | + span.set_inputs(parameters) |
| 373 | + |
| 374 | + nonlocal result |
| 375 | + result = _execute_uc_function(function_info, parameters, **kwargs) |
| 376 | + |
| 377 | + # Re-raise errors so they can get traced |
| 378 | + if result.error: |
| 379 | + raise Exception(result.error) |
| 380 | + |
| 381 | + return process_retriever_output(result) |
| 382 | + |
| 383 | + try: |
| 384 | + execute_retriever(parameters) |
| 385 | + except Exception: # Catch all errors that are re-raised |
| 386 | + pass |
| 387 | + |
| 388 | + return result |
| 389 | + except ImportError as e: |
| 390 | + _logger.warn( |
| 391 | + f"Skipping tracing {function_info.full_name} as a retriever because of the following error:\n {e}" |
386 | 392 | )
|
387 |
| - pass |
| 393 | + return _execute_uc_function(function_info, parameters, **kwargs) |
0 commit comments