Skip to content

Commit b7d187a

Browse files
authored
[Bug fix] Support search traces by string feedback / expectation values (mlflow#19719)
Signed-off-by: dbczumar <[email protected]>
1 parent 6b11eb4 commit b7d187a

File tree

3 files changed

+93
-5
lines changed

3 files changed

+93
-5
lines changed

mlflow/store/tracking/sqlalchemy_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5388,14 +5388,14 @@ def _get_filter_clauses_for_search_traces(filter_string, session, dialect):
53885388
span_filters.append(span_subquery)
53895389
continue
53905390
elif SearchTraceUtils.is_assessment(key_type, key_name, comparator):
5391-
# Create subquery to find traces with matching feedback
5392-
# Filter by feedback name and check the value
5391+
# Create subquery to find traces with matching assessments
5392+
# Filter by assessment name and check the value
53935393
feedback_subquery = (
53945394
session.query(SqlAssessments.trace_id.label("request_id"))
53955395
.filter(
53965396
SqlAssessments.assessment_type == key_type,
53975397
SqlAssessments.name == key_name,
5398-
SearchTraceUtils.get_sql_comparison_func(comparator, dialect)(
5398+
SearchTraceUtils._get_sql_json_comparison_func(comparator, dialect)(
53995399
SqlAssessments.value, value
54005400
),
54015401
)

mlflow/utils/search_utils.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import shlex
88
from dataclasses import asdict, dataclass
9-
from typing import Any
9+
from typing import TYPE_CHECKING, Any, Callable
1010

1111
import sqlparse
1212
from packaging.version import Version
@@ -37,6 +37,12 @@
3737
MLFLOW_DATASET_CONTEXT,
3838
)
3939

40+
if TYPE_CHECKING:
41+
from sqlalchemy.sql.elements import ClauseElement, ColumnElement
42+
43+
# MSSQL collation for case-sensitive string comparisons
44+
_MSSQL_CASE_SENSITIVE_COLLATION = "Japanese_Bushu_Kakusu_100_CS_AS_KS_WS"
45+
4046

4147
def _convert_like_pattern_to_regex(pattern, flags=0):
4248
if not pattern.startswith("%"):
@@ -245,7 +251,7 @@ def mssql_comparison_func(column, value):
245251
if not isinstance(column.type, sa.types.String):
246252
return comparison_func(column, value)
247253

248-
collated = column.collate("Japanese_Bushu_Kakusu_100_CS_AS_KS_WS")
254+
collated = column.collate(_MSSQL_CASE_SENSITIVE_COLLATION)
249255
return comparison_func(collated, value)
250256

251257
def mysql_comparison_func(column, value):
@@ -1861,6 +1867,62 @@ def is_assessment(cls, key_type, key_name, comparator):
18611867
return True
18621868
return False
18631869

1870+
@staticmethod
1871+
def _get_sql_json_comparison_func(
1872+
comparator: str, dialect: str
1873+
) -> Callable[["ColumnElement", str], "ClauseElement"]:
1874+
"""
1875+
Returns a comparison function for JSON-serialized values.
1876+
1877+
Assessment values are stored as JSON primitives in the database:
1878+
- Boolean False -> false (no quotes in JSON)
1879+
- Numeric value 5 -> 5 (no quotes in JSON)
1880+
- String "yes" -> '"yes"' (WITH quotes in JSON)
1881+
1882+
For equality comparisons, we match either the raw JSON primitive value
1883+
(for booleans and numeric values) or the JSON-serialized value (for strings).
1884+
"""
1885+
import sqlalchemy as sa
1886+
1887+
def mysql_json_equality_inequality_comparison(
1888+
column: "ColumnElement", value: str
1889+
) -> "ClauseElement":
1890+
# MySQL is case insensitive by default, so we need to use the BINARY operator
1891+
# for case sensitive comparisons. We check both the raw value (for booleans/numbers)
1892+
# and the JSON-serialized value (for strings).
1893+
json_string_value = json.dumps(value)
1894+
col_ref = f"{column.class_.__tablename__}.{column.key}"
1895+
template = (
1896+
f"(({col_ref} = :value1 AND BINARY {col_ref} = :value1) OR "
1897+
f"({col_ref} = :value2 AND BINARY {col_ref} = :value2))"
1898+
)
1899+
if comparator == "!=":
1900+
template = f"NOT {template}"
1901+
return sa.text(template).bindparams(
1902+
sa.bindparam("value1", value=value, unique=True),
1903+
sa.bindparam("value2", value=json_string_value, unique=True),
1904+
)
1905+
1906+
def json_equality_inequality_comparison(
1907+
column: "ColumnElement", value: str
1908+
) -> "ClauseElement":
1909+
# MSSQL uses collation for case-sensitive comparisons on String columns
1910+
if dialect == MSSQL:
1911+
column = column.collate(_MSSQL_CASE_SENSITIVE_COLLATION)
1912+
1913+
json_string_value = json.dumps(value)
1914+
clause = sa.or_(column == value, column == json_string_value)
1915+
if comparator == "!=":
1916+
clause = sa.not_(clause)
1917+
return clause
1918+
1919+
if comparator not in ("=", "!="):
1920+
return SearchTraceUtils.get_sql_comparison_func(comparator, dialect)
1921+
elif dialect == MYSQL:
1922+
return mysql_json_equality_inequality_comparison
1923+
else:
1924+
return json_equality_inequality_comparison
1925+
18641926
@classmethod
18651927
def _valid_entity_type(cls, entity_type):
18661928
entity_type = cls._trim_backticks(entity_type)

tests/store/tracking/test_sqlalchemy_store.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5524,6 +5524,13 @@ def test_search_traces_with_feedback_and_expectation_filters(store: SqlAlchemySt
55245524
source=AssessmentSource(source_type="HUMAN", source_id="[email protected]"),
55255525
)
55265526

5527+
feedback4 = Feedback(
5528+
trace_id=trace1_id,
5529+
name="quality",
5530+
value="high",
5531+
source=AssessmentSource(source_type="HUMAN", source_id="[email protected]"),
5532+
)
5533+
55275534
# Create expectations for trace3 and trace4
55285535
expectation1 = Expectation(
55295536
trace_id=trace3_id,
@@ -5546,13 +5553,22 @@ def test_search_traces_with_feedback_and_expectation_filters(store: SqlAlchemySt
55465553
source=AssessmentSource(source_type="CODE", source_id="latency_monitor"),
55475554
)
55485555

5556+
expectation4 = Expectation(
5557+
trace_id=trace3_id,
5558+
name="priority",
5559+
value="urgent",
5560+
source=AssessmentSource(source_type="CODE", source_id="priority_checker"),
5561+
)
5562+
55495563
# Store assessments
55505564
store.create_assessment(feedback1)
55515565
store.create_assessment(feedback2)
55525566
store.create_assessment(feedback3)
5567+
store.create_assessment(feedback4)
55535568
store.create_assessment(expectation1)
55545569
store.create_assessment(expectation2)
55555570
store.create_assessment(expectation3)
5571+
store.create_assessment(expectation4)
55565572

55575573
# Test: Search for traces with correctness feedback = True
55585574
traces, _ = store.search_traces([exp_id], filter_string='feedback.correctness = "true"')
@@ -5569,6 +5585,11 @@ def test_search_traces_with_feedback_and_expectation_filters(store: SqlAlchemySt
55695585
assert len(traces) == 1
55705586
assert traces[0].request_id == trace2_id
55715587

5588+
# Test: Search for traces with string-valued feedback
5589+
traces, _ = store.search_traces([exp_id], filter_string='feedback.quality = "high"')
5590+
assert len(traces) == 1
5591+
assert traces[0].request_id == trace1_id
5592+
55725593
# Test: Search for traces with response_length expectation = 150
55735594
traces, _ = store.search_traces([exp_id], filter_string='expectation.response_length = "150"')
55745595
assert len(traces) == 1
@@ -5584,6 +5605,11 @@ def test_search_traces_with_feedback_and_expectation_filters(store: SqlAlchemySt
55845605
assert len(traces) == 1
55855606
assert traces[0].request_id == trace4_id
55865607

5608+
# Test: Search for traces with string-valued expectation
5609+
traces, _ = store.search_traces([exp_id], filter_string='expectation.priority = "urgent"')
5610+
assert len(traces) == 1
5611+
assert traces[0].request_id == trace3_id
5612+
55875613
# Test: Combined filter with AND - trace with multiple expectations
55885614
traces, _ = store.search_traces(
55895615
[exp_id],

0 commit comments

Comments
 (0)