Skip to content

Commit 9253203

Browse files
Fix execute_function string parsing with serverless (unitycatalog#689)
**PR Checklist** - [ ] A description of the changes is added to the description of this PR. - [ ] If there is a related issue, make sure it is linked to this PR. - [ ] If you've fixed a bug or added code that should be tested, add tests! - [ ] If you've added or modified a feature, documentation in `docs` is updated **Description of changes** Update the serverless execution logic to pass arguments as named parameters (referenced in the query with a `:` prefix) instead of inlining them as quoted SQL literals. Wrap any exception raised during execution to provide a better error message on failure. Drop the previous sanitization logic, as it is no longer needed :) The SQL side handles strings that combine single and double quotes correctly. Updated tests. <!-- Please state what you've changed and how it might affect the users. --> --------- Signed-off-by: serena-ruan_data <[email protected]> Signed-off-by: Serena Ruan <[email protected]> Co-authored-by: Ben Wilson <[email protected]>
1 parent f003929 commit 9253203

File tree

6 files changed

+114
-258
lines changed

6 files changed

+114
-258
lines changed

ai/core/src/unitycatalog/ai/core/databricks.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222
)
2323
from unitycatalog.ai.core.paged_list import PagedList
2424
from unitycatalog.ai.core.utils.callable_utils import generate_sql_function_body
25-
from unitycatalog.ai.core.utils.function_processing_utils import (
26-
sanitize_string_inputs_of_function_params,
27-
)
2825
from unitycatalog.ai.core.utils.type_utils import (
2926
column_type_to_python_type,
3027
convert_timedelta_to_interval_str,
@@ -146,7 +143,15 @@ def retry_on_session_expiration(func):
146143
def wrapper(self, *args, **kwargs):
147144
for attempt in range(1, max_attempts + 1):
148145
try:
149-
return func(self, *args, **kwargs)
146+
result = func(self, *args, **kwargs)
147+
# for non-session related error in the result, we should directly return the result
148+
if (
149+
isinstance(result, FunctionExecutionResult)
150+
and result.error
151+
and SESSION_EXCEPTION_MESSAGE in result.error
152+
):
153+
raise Exception(result.error)
154+
return result
150155
except Exception as e:
151156
error_message = str(e)
152157
if SESSION_EXCEPTION_MESSAGE in error_message:
@@ -657,7 +662,11 @@ def _execute_uc_functions_with_serverless(
657662
_logger.info("Using databricks connect to execute functions with serverless compute.")
658663
self.set_default_spark_session()
659664
sql_command = get_execute_function_sql_command(function_info, parameters)
660-
result = self.spark.sql(sqlQuery=sql_command)
665+
try:
666+
result = self.spark.sql(sqlQuery=sql_command.sql_query, args=sql_command.args or None)
667+
except Exception as e:
668+
error = f"Failed to execute function with command `{sql_command}`; Error: {e}"
669+
return FunctionExecutionResult(error=error)
661670
if is_scalar(function_info):
662671
return FunctionExecutionResult(format="SCALAR", value=str(result.collect()[0][0]))
663672
else:
@@ -813,7 +822,15 @@ def get_execute_function_sql_stmt(
813822
return ParameterizedStatement(statement=statement, parameters=output_params)
814823

815824

816-
def get_execute_function_sql_command(function: "FunctionInfo", parameters: Dict[str, Any]) -> str:
825+
@dataclass
826+
class SparkSqlCommand:
827+
sql_query: str
828+
args: dict[str, Any]
829+
830+
831+
def get_execute_function_sql_command(
832+
function: "FunctionInfo", parameters: Dict[str, Any]
833+
) -> SparkSqlCommand:
817834
from databricks.sdk.service.catalog import ColumnTypeName
818835

819836
sql_query = ""
@@ -824,6 +841,7 @@ def get_execute_function_sql_command(function: "FunctionInfo", parameters: Dict[
824841
f"SELECT * FROM `{function.catalog_name}`.`{function.schema_name}`.`{function.name}`("
825842
)
826843

844+
params_dict: dict[str, Any] = {}
827845
if parameters and function.input_params and function.input_params.parameters:
828846
args: List[str] = []
829847
use_named_args = False
@@ -865,11 +883,9 @@ def get_execute_function_sql_command(function: "FunctionInfo", parameters: Dict[
865883
param_value, Decimal
866884
):
867885
param_value = float(param_value)
868-
# Handle all other types as string types and santitize escape characters
869-
# since this is likely a code block being executed
870-
param_value = sanitize_string_inputs_of_function_params(param_value)
871-
arg_clause += f"'{param_value}'"
886+
arg_clause += f":{param_info.name}"
887+
params_dict[param_info.name] = param_value
872888
args.append(arg_clause)
873889
sql_query += ",".join(args)
874890
sql_query += ")"
875-
return sql_query
891+
return SparkSqlCommand(sql_query=sql_query, args=params_dict)

ai/core/src/unitycatalog/ai/core/utils/function_processing_utils.py

-61
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import ast
21
import decimal
32
import json
43
import logging
@@ -283,63 +282,3 @@ def supported_function_info_types():
283282
pass
284283

285284
return types
286-
287-
288-
def is_python_code(code_str: str) -> bool:
289-
"""Check if the provided string is valid Python code."""
290-
try:
291-
ast.parse(code_str)
292-
return True
293-
except SyntaxError:
294-
return False
295-
296-
297-
def convert_quoting_to_sql_safe_format(string_value: str) -> str:
298-
"""
299-
Convert a string to a SQL-safe format by escaping single quotes.
300-
301-
Args:
302-
string_value: The string to be converted.
303-
304-
Returns:
305-
str: The SQL-safe string.
306-
"""
307-
has_single_quote = "'" in string_value
308-
has_double_quote = '"' in string_value
309-
310-
if not has_single_quote and not has_double_quote:
311-
return string_value
312-
313-
if has_single_quote and not has_double_quote:
314-
string_value = string_value.replace("'", '"')
315-
elif has_single_quote and has_double_quote:
316-
raise ValueError(
317-
"The argument passed in has been detected as Python code that contains both single and double quotes. "
318-
"This is not supported. Code must use only one style of quotation. Please fix the code and try again."
319-
)
320-
return string_value
321-
322-
323-
def sanitize_string_inputs_of_function_params(param_value: Any) -> str:
324-
"""
325-
Sanitize string inputs of function parameters to allow for code block submission.
326-
327-
Args:
328-
param_value: The value of the parameter to sanitize.
329-
330-
Returns:
331-
A sanitized string of the argument value.
332-
"""
333-
334-
if isinstance(param_value, str) and is_python_code(param_value):
335-
# Escape single quotes, backslashes, and control characters that would otherwise break Python code execution
336-
parsed = (
337-
param_value.replace("\\", "\\\\")
338-
.replace("\r", "\\r")
339-
.replace("\n", "\\n")
340-
.replace("\t", "\\t")
341-
)
342-
quotes_parsed = convert_quoting_to_sql_safe_format(parsed)
343-
else:
344-
quotes_parsed = param_value
345-
return str(quotes_parsed)

ai/core/src/unitycatalog/ai/test_utils/function_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def create_python_function_and_cleanup(
8686
) -> Generator[FunctionObj, None, None]:
8787
func_name = f"{CATALOG}.{schema}.{func.__name__}"
8888
try:
89-
func_info = client.create_python_function(func=func, catalog=CATALOG, schema=schema)
89+
func_info = client.create_python_function(
90+
func=func, catalog=CATALOG, schema=schema, replace=True
91+
)
9092
yield FunctionObj(
9193
full_function_name=func_name,
9294
comment=func_info.comment,

ai/core/tests/core/databricks/test_databricks_integration_tests.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import os
23
import time
34
from typing import Callable, Dict, List
@@ -38,13 +39,16 @@
3839
UCAI_DATABRICKS_WAREHOUSE_RETRY_TIMEOUT,
3940
)
4041
from unitycatalog.ai.test_utils.client_utils import (
42+
USE_SERVERLESS,
4143
client, # noqa: F401
44+
get_client,
4245
requires_databricks,
4346
retry_flaky_test,
4447
serverless_client, # noqa: F401
4548
)
4649
from unitycatalog.ai.test_utils.function_utils import (
4750
CATALOG,
51+
create_function_and_cleanup,
4852
create_python_function_and_cleanup,
4953
generate_func_name_and_cleanup,
5054
random_func_name,
@@ -418,14 +422,60 @@ def simple_func(x: int) -> str:
418422
print(calculate_sum([1, 2, 3, 4, 5]))""",
419423
"15\n",
420424
),
425+
# Simple print statement
426+
("print('Hello, world!')", "Hello, world!\n"),
427+
# Code with double quotes
428+
('print("He said, \\"Hi!\\"")', 'He said, "Hi!"\n'),
429+
# Code with backslashes
430+
(r"print('C:\\path\\into\\dir')", "C:\\path\\into\\dir\n"),
431+
# Multi-line code with newlines
432+
("for i in range(3):\n print(i)", "0\n1\n2\n"),
433+
# Code with tabs and indents
434+
("def greet(name):\n print(f'Hello, {name}!')\ngreet('Alice')", "Hello, Alice!\n"),
435+
# Code with special characters
436+
("print('Special chars: !@#$%^&*()')", "Special chars: !@#$%^&*()\n"),
437+
# Unicode characters
438+
("print('Unicode test: ü, é, 漢字')", "Unicode test: ü, é, 漢字\n"),
439+
# Code with comments
440+
("# This is a comment\nprint('Comment test')", "Comment test\n"),
441+
# Code raising an exception
442+
(
443+
"try:\n raise ValueError('Test error')\nexcept Exception as e:\n print(f'Caught an error: {e}')",
444+
"Caught an error: Test error\n",
445+
),
446+
# Code with triple quotes
447+
('print("""Triple quote test""")', "Triple quote test\n"),
448+
# Code with raw strings
449+
("print('Raw string: \\\\n new line')", "Raw string: \\n new line\n"),
450+
# Empty code string
451+
("", ""),
452+
# Code with carriage return
453+
("print('Line1\\\\rLine2')", "Line1\\rLine2\n"),
454+
# Code with encoding declarations (Note: encoding declarations should be in the first or second line)
455+
("# -*- coding: utf-8 -*-\nprint('Encoding test')", "Encoding test\n"),
456+
# Code importing a standard library
457+
("import math\nprint(math.pi)", f"{math.pi}\n"),
458+
# Code with nested functions
459+
(
460+
"def outer():\n def inner():\n return 'Nested'\n return inner()\nprint(outer())",
461+
"Nested\n",
462+
),
463+
# Code with list comprehensions
464+
("squares = [x**2 for x in range(5)]\nprint(squares)", "[0, 1, 4, 9, 16]\n"),
465+
# Code with multi-line strings
466+
("multi_line = '''Line1\nLine2\nLine3'''\nprint(multi_line)", "Line1\nLine2\nLine3\n"),
421467
]
422468

423469

424470
@requires_databricks
425471
@pytest.mark.parametrize("code, expected_output", integration_test_cases)
472+
@pytest.mark.parametrize("use_serverless", [True, False])
426473
def test_execute_python_code_integration(
427-
client: DatabricksFunctionClient, code: str, expected_output: str
474+
code: str, expected_output: str, use_serverless: bool, monkeypatch
428475
):
476+
monkeypatch.setenv(USE_SERVERLESS, str(use_serverless))
477+
client = get_client()
478+
429479
def python_exec(code: str) -> str:
430480
"""
431481
Execute the provided Python code and return the output.
@@ -451,3 +501,28 @@ def python_exec(code: str) -> str:
451501
assert result.error is None, f"Function execution failed with error: {result.error}"
452502

453503
assert result.value == expected_output
504+
505+
506+
@requires_databricks
507+
@pytest.mark.parametrize("use_serverless", [True, False])
508+
@pytest.mark.parametrize(
509+
"text",
510+
[
511+
"MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It was developed by Databricks and is now a part of the Linux Foundation's AI Foundation.",
512+
"print('Hello, \"world!\"')",
513+
"'return '2' + \"" '3"' "' is a valid input to this function",
514+
],
515+
)
516+
def test_string_param_passing_work(text: str, use_serverless: bool, monkeypatch):
517+
monkeypatch.setenv(USE_SERVERLESS, str(use_serverless))
518+
client = get_client()
519+
function_name = random_func_name(schema=SCHEMA)
520+
summarize_in_20_words = f"""CREATE OR REPLACE FUNCTION {function_name}(text STRING)
521+
RETURNS STRING
522+
RETURN SELECT ai_summarize(text, 20)
523+
"""
524+
with create_function_and_cleanup(client=client, schema=SCHEMA, sql_body=summarize_in_20_words):
525+
result = client.execute_function(function_name, {"text": text})
526+
assert result.error is None, f"Function execution failed with error: {result.error}"
527+
# number of words should be no more than 20
528+
assert len(result.value.split(" ")) <= 20

0 commit comments

Comments
 (0)