Skip to content

Commit 661b080

Browse files
authored
Merge pull request #181 from bsatapat-jpg/dev_biswajit
[LEADS-218] Improve partial tool match reporting with detailed extra tool information
2 parents adc2689 + 93b7143 commit 661b080

3 files changed

Lines changed: 250 additions & 39 deletions

File tree

src/lightspeed_evaluation/core/metrics/custom/tool_eval.py

Lines changed: 155 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ def evaluate_tool_calls(
2626
tuple: (success: bool, details: str)
2727
"""
2828
try:
29-
# Try each set until one matches
29+
# Try each set until one matches, track best result for failure reporting
30+
best_result: dict[str, Any] = {}
31+
best_matched = -1
32+
3033
for i, expected_set in enumerate(expected):
3134
result = compare_tool_calls(
3235
expected_set, actual, ordered=ordered, full_match=full_match
@@ -40,8 +43,21 @@ def evaluate_tool_calls(
4043
match_stats=result.get("stats"),
4144
)
4245

43-
# If all sets fail, return failure status & message
44-
return _create_failure_message(expected, actual)
46+
# Track the best matching alternative (most matched tools)
47+
stats = result.get("stats", {})
48+
matched = stats.get("matched", 0)
49+
if matched > best_matched:
50+
best_matched = matched
51+
best_result = result
52+
53+
# If all sets fail, return failure with best match stats
54+
return _create_failure_message(
55+
expected,
56+
actual,
57+
ordered=ordered,
58+
full_match=full_match,
59+
best_stats=best_result.get("stats"),
60+
)
4561

4662
except (AttributeError, TypeError, ValueError) as e:
4763
logger.error("Error during tool evaluation: %s", e)
@@ -82,12 +98,20 @@ def compare_tool_calls(
8298
mismatch_message = "Tool calls count mismatch: expected %d, got %d"
8399

84100
if not full_match:
85-
matched, total = _compare_partial(expected_normalized, actual_normalized)
101+
matched, total, extra_tools, unmatched_expected = _compare_partial(
102+
expected_normalized, actual_normalized
103+
)
86104
# Partial match succeeds if all expected tools matched (subset matching)
87105
success = matched == total
88106
return {
89107
"success": success,
90-
"stats": {"matched": matched, "total": total, "unmatched": total - matched},
108+
"stats": {
109+
"matched": matched,
110+
"total": total,
111+
"unmatched": total - matched,
112+
"extra_actual_tools": extra_tools,
113+
"unmatched_expected_tools": unmatched_expected,
114+
},
91115
}
92116

93117
# Full match (default)
@@ -121,7 +145,7 @@ def _normalize_sequences(
121145
def _compare_partial(
122146
expected: list[list[dict[str, Any]]],
123147
actual: list[list[dict[str, Any]]],
124-
) -> tuple[int, int]:
148+
) -> tuple[int, int, list[str], list[str]]:
125149
"""Compare tool calls with partial matching.
126150
127151
Counts how many expected sequences are found in actual.
@@ -136,29 +160,50 @@ def _compare_partial(
136160
actual: Actual tool call sequences (pre-normalized)
137161
138162
Returns:
139-
Tuple of (matched_count, total_expected)
163+
Tuple of (matched_count, total_expected, extra_actual_tools, unmatched_expected_tools)
140164
"""
141165
if not expected:
142-
return (0, 0)
166+
extra_tools = [_get_sequence_tool_names(seq) for seq in actual]
167+
return (0, 0, extra_tools, [])
143168

144169
matched = 0
145170
used_indices: set[int] = set()
171+
matched_expected_indices: set[int] = set()
146172

147-
for expected_seq in expected:
173+
for i, expected_seq in enumerate(expected):
148174
for j, actual_seq in enumerate(actual):
149175
if j not in used_indices and _compare_tool_call_sequence(
150176
expected_seq, actual_seq
151177
):
152178
matched += 1
153179
used_indices.add(j)
180+
matched_expected_indices.add(i)
154181
break
155182

183+
extra_tools = [
184+
_get_sequence_tool_names(actual[i])
185+
for i in range(len(actual))
186+
if i not in used_indices
187+
]
188+
unmatched_expected = [
189+
_get_sequence_tool_names(expected[i])
190+
for i in range(len(expected))
191+
if i not in matched_expected_indices
192+
]
156193
logger.debug(
157-
"Partial match: %d/%d expected sequences found",
194+
"Partial match: %d/%d expected sequences found, extra: %s, unmatched: %s",
158195
matched,
159196
len(expected),
197+
extra_tools,
198+
unmatched_expected,
160199
)
161-
return (matched, len(expected))
200+
return (matched, len(expected), extra_tools, unmatched_expected)
201+
202+
203+
def _get_sequence_tool_names(sequence: list[dict[str, Any]]) -> str:
204+
"""Get tool names from a sequence as a comma-separated string."""
205+
names = [tc.get("tool_name", "unknown") for tc in sequence]
206+
return ", ".join(names) if len(names) > 1 else (names[0] if names else "unknown")
162207

163208

164209
def _compare_tool_call_sequence(
@@ -371,12 +416,77 @@ def _compare_tool_result(expected: dict[str, Any], actual: dict[str, Any]) -> bo
371416
return True
372417

373418

419+
def _get_mode_suffix(ordered: bool, full_match: bool) -> str:
420+
"""Get the mode suffix string for messages.
421+
422+
Args:
423+
ordered: Whether ordered matching was used
424+
full_match: Whether full or partial matching was used
425+
426+
Returns:
427+
Mode suffix like "(partial, ordered)"
428+
"""
429+
match_mode = "full" if full_match else "partial"
430+
order_mode = "ordered" if ordered else "unordered"
431+
return f"({match_mode}, {order_mode})"
432+
433+
434+
def _format_match_stats(
435+
match_stats: dict[str, Any],
436+
ordered: bool,
437+
full_match: bool,
438+
) -> str:
439+
"""Format match statistics into a human-readable string.
440+
441+
Args:
442+
match_stats: Dict with matched/total/unmatched/extra_actual_tools/unmatched_expected_tools
443+
ordered: Whether ordered matching was used
444+
full_match: Whether full or partial matching was used
445+
446+
Returns:
447+
Formatted statistics string
448+
"""
449+
matched = match_stats["matched"]
450+
total = match_stats["total"]
451+
unmatched = match_stats["unmatched"]
452+
453+
# Extra actual tools (from actual that weren't used)
454+
extra_tools: list[str] = match_stats.get("extra_actual_tools", [])
455+
extra_count = len(extra_tools)
456+
extra_info = f"[{', '.join(extra_tools)}]" if extra_tools else "none"
457+
458+
# Handle empty expected (no tool calls expected for this alternative)
459+
if total == 0:
460+
if extra_count > 0:
461+
return (
462+
f"No expected tool calls (skip scenario), "
463+
f"but {extra_count} actual: {extra_info} "
464+
f"{_get_mode_suffix(ordered, full_match)}"
465+
)
466+
return (
467+
f"No expected tool calls (skip scenario) "
468+
f"{_get_mode_suffix(ordered, full_match)}"
469+
)
470+
471+
# Unmatched expected tools (from expected that didn't match)
472+
unmatched_expected: list[str] = match_stats.get("unmatched_expected_tools", [])
473+
unmatched_info = (
474+
f"[{', '.join(unmatched_expected)}]" if unmatched_expected else "none"
475+
)
476+
477+
return (
478+
f"{matched}/{total} expected matched, {unmatched} unmatched: {unmatched_info}, "
479+
f"{extra_count} extra in response: {extra_info} "
480+
f"{_get_mode_suffix(ordered, full_match)}"
481+
)
482+
483+
374484
def _create_success_message(
375485
index: int,
376486
expected_set: list[list[dict[str, Any]]],
377487
ordered: bool = True,
378488
full_match: bool = True,
379-
match_stats: dict[str, int] | None = None,
489+
match_stats: dict[str, Any] | None = None,
380490
) -> tuple[bool, str]:
381491
"""Create success message based on match type.
382492
@@ -385,54 +495,65 @@ def _create_success_message(
385495
expected_set: The matched expected tool call set
386496
ordered: Whether ordered matching was used
387497
full_match: Whether full or partial matching was used
388-
match_stats: Optional dict with matched/total/unmatched counts for partial match
498+
match_stats: Optional dict with matched/total/unmatched/extra_actual_tools
389499
390500
Returns:
391501
Tuple of (True, success message)
392502
"""
393503
pattern_type = "Primary pattern" if index == 0 else f"Alternative {index + 1}"
394-
order_mode = "ordered" if ordered else "unordered"
395-
match_mode = "full" if full_match else "partial"
396504

397-
# Determine message based on what matched
398-
if len(expected_set) == 0:
399-
# Empty alternative matched - index 0 can never be empty due to constraints
505+
# Check match_stats first to include extra tools info (for partial match mode)
506+
if match_stats:
507+
message = f"Tool calls: {_format_match_stats(match_stats, ordered, full_match)}"
508+
elif len(expected_set) == 0:
400509
message = "No tool calls made (valid alternate skip scenario)"
401-
elif match_stats:
402-
# Include match statistics for partial match
403-
matched = match_stats["matched"]
404-
total = match_stats["total"]
405-
unmatched = match_stats["unmatched"]
406-
message = (
407-
f"Tool calls: {matched}/{total} matched, {unmatched} unmatched "
408-
f"({match_mode}, {order_mode})"
409-
)
410510
else:
411511
message = (
412512
f"Tool calls match expected structure and arguments "
413-
f"({match_mode}, {order_mode})"
513+
f"{_get_mode_suffix(ordered, full_match)}"
414514
)
415515

416516
return True, f"{pattern_type} matched: {message}"
417517

418518

419519
def _create_failure_message(
420-
expected: list[list[list[dict[str, Any]]]], actual: list[list[dict[str, Any]]]
520+
expected: list[list[list[dict[str, Any]]]],
521+
actual: list[list[dict[str, Any]]],
522+
ordered: bool = True,
523+
full_match: bool = True,
524+
best_stats: dict[str, Any] | None = None,
421525
) -> tuple[bool, str]:
422-
"""Create failure message with helpful context."""
423-
# If we reach here, none of the alternatives matched
526+
"""Create failure message with helpful context.
424527
528+
Args:
529+
expected: Expected tool call patterns (with alternatives)
530+
actual: Actual tool calls from API response
531+
ordered: Whether ordered matching was used
532+
full_match: Whether full or partial matching was used
533+
best_stats: Stats from best matching alternative (most matched tools)
534+
535+
Returns:
536+
Tuple of (False, failure message)
537+
"""
425538
if len(actual) == 0:
426539
return (
427540
False,
428541
"No actual tool calls made and this is not set as an expected alternative",
429542
)
430543

431-
return (
432-
False,
433-
f"Tool calls made but didn't match any of the {len(expected)} expected pattern(s)",
544+
base_msg = (
545+
f"Tool calls made but didn't match any of the {len(expected)} "
546+
f"expected pattern(s)"
434547
)
435548

549+
if best_stats and not full_match:
550+
return (
551+
False,
552+
f"{base_msg}: {_format_match_stats(best_stats, ordered, full_match)}",
553+
)
554+
555+
return (False, base_msg)
556+
436557

437558
def format_tool_calls_for_logging(tool_calls: list[list[dict[str, Any]]]) -> str:
438559
"""Format tool calls for logging purposes."""

tests/unit/core/metrics/custom/test_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_config_match_partial_from_metadata(self, mocker: MockerFixture) -> None
122122

123123
assert score == 1.0
124124
assert "partial" in reason
125-
assert "1/1 matched" in reason
125+
assert "1/1 expected matched" in reason
126126

127127
def test_config_from_system_defaults_via_metric_manager(
128128
self, mocker: MockerFixture

0 commit comments

Comments
 (0)