9 changes: 8 additions & 1 deletion .env.example
@@ -3,4 +3,11 @@ HUBSPOT_CONNECTED_ACCOUNT_ID=
OPENAI_API_KEY=
SUPERFACE_API_KEY=
SUPERFACE_BASE_URL=https://pod.superface.ai
TEST_PROMPT=Create a new lead, John Doe ([email protected]), and the company ACME Ltd (acme.com). Check for duplicate companies by name.
TEST_PROMPT=Create a new lead, John Doe ([email protected]), and the company ACME Ltd (acme.com). Check for duplicate companies by name.

# Benchmark environment variables
BENCHMARK_MODEL=gpt-4o
SUPERFACE_BENCHMARK_USER_ID=benchmark_test
BENCHMARK_RUNS_PER_TEST=3
BENCHMARK_TEMPERATURE=0.1
BENCHMARK_SEED=42
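
For context, a minimal sketch of how these settings might be consumed from Python, assuming they have already been exported into the process environment (for example via a dotenv loader); the variable names come from the file above, and the fallback defaults are illustrative only:

import os

# Hypothetical loader for the benchmark settings defined in .env.example;
# the defaults mirror the example values and are illustrative only.
BENCHMARK_MODEL = os.getenv("BENCHMARK_MODEL", "gpt-4o")
BENCHMARK_USER_ID = os.getenv("SUPERFACE_BENCHMARK_USER_ID", "benchmark_test")
RUNS_PER_TEST = int(os.getenv("BENCHMARK_RUNS_PER_TEST", "3"))
TEMPERATURE = float(os.getenv("BENCHMARK_TEMPERATURE", "0.1"))
SEED = int(os.getenv("BENCHMARK_SEED", "42"))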
124 changes: 124 additions & 0 deletions benchmarks/benchmark_results.py
@@ -0,0 +1,124 @@
import json
import os
import statistics
from datetime import datetime

def compile_benchmark_results(benchmark_results, bench_data, num_runs_per_test, benchmark_name="Superface Specialist Benchmark", environment=None):
"""
Compile benchmark results into a structured format with summary statistics.

Args:
benchmark_results: List of individual test results
bench_data: Original benchmark test data
num_runs_per_test: Number of runs per test case
benchmark_name: Name of the benchmark
environment: Dictionary containing environment variables to include in results

Returns:
Dictionary containing compiled results and statistics
"""
# Compile final results
final_results = {
"timestamp": datetime.now().isoformat(),
"benchmark_name": benchmark_name,
"total_tests": len(bench_data),
"successful_tests": sum(1 for r in benchmark_results if r['success']),
"failed_tests": sum(1 for r in benchmark_results if not r['success']),
"tool_calls_match_count": sum(1 for r in benchmark_results if r.get('tool_calls_match', False)),
"success_rate": (sum(1 for r in benchmark_results if r['success']) / len(bench_data)) * 100,
"tool_call_match_rate": (sum(1 for r in benchmark_results if r.get('tool_calls_match', False)) / len(bench_data)) * 100
}

# Add environment variables to the results if provided
if environment:
final_results["environment"] = environment

# For JSON serialization, convert CrewOutput objects to strings
serializable_results = []
for r in benchmark_results:
serializable_result = r.copy()
if r['success'] and r['result'] is not None:
serializable_result['result'] = str(r['result'])
# Make sure actual_tool_calls is included
if 'actual_tool_calls' not in serializable_result:
serializable_result['actual_tool_calls'] = []
serializable_results.append(serializable_result)

# Update the final_results with serializable results
final_results["results"] = serializable_results

# Calculate additional summary statistics
run_durations = [r['duration_seconds'] for r in benchmark_results]
if run_durations:
final_results["summary_stats"] = {
"avg_test_duration": statistics.mean(run_durations),
"min_test_duration": min(run_durations),
"max_test_duration": max(run_durations),
"std_dev_test_duration": statistics.stdev(run_durations) if len(run_durations) > 1 else 0,
"total_benchmark_duration": sum(run_durations),
"runs_per_test": num_runs_per_test
}

# Collect all mismatch reasons across tests
all_mismatch_reasons = {}
for result in benchmark_results:
if 'stats' in result and 'mismatch_summary' in result['stats']:
for reason, count in result['stats']['mismatch_summary'].items():
all_mismatch_reasons[reason] = all_mismatch_reasons.get(reason, 0) + count

final_results["summary_stats"]["mismatch_reasons"] = all_mismatch_reasons

return final_results

def save_benchmark_results(final_results, prefix="benchmark_results"):
"""
Save benchmark results to a JSON file.

Args:
final_results: Dictionary containing benchmark results
prefix: Prefix for the filename

Returns:
Path to the saved file
"""
results_filename = f"./results/{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

    # Ensure the results directory exists
    os.makedirs("./results", exist_ok=True)

# Make sure success_rate and tool_call_match_rate are included
if 'success_rate' not in final_results:
final_results['success_rate'] = (final_results['successful_tests'] / final_results['total_tests']) * 100

if 'tool_call_match_rate' not in final_results:
final_results['tool_call_match_rate'] = (final_results['tool_calls_match_count'] / final_results['total_tests']) * 100

with open(results_filename, "w") as f:
json.dump(final_results, f, indent=2)

return results_filename

def print_benchmark_summary(final_results):
"""
Print a summary of benchmark results.

Args:
final_results: Dictionary containing benchmark results
"""
print("\n=== Benchmark Summary ===")
print(f"Total tests: {final_results['total_tests']}")
print(f"Successful tests: {final_results['successful_tests']}")
print(f"Failed tests: {final_results['failed_tests']}")
print(f"Tests with matching tool calls: {final_results['tool_calls_match_count']}")
print(f"Success rate: {final_results['success_rate']:.2f}%")
print(f"Tool call match rate: {final_results['tool_call_match_rate']:.2f}%")

if "summary_stats" in final_results:
stats = final_results["summary_stats"]
print(f"\nAverage test duration: {stats['avg_test_duration']:.2f} seconds")
print(f"Total benchmark duration: {stats['total_benchmark_duration']:.2f} seconds")

if "mismatch_reasons" in stats and stats["mismatch_reasons"]:
print("\nMismatch reasons:")
for reason, count in stats["mismatch_reasons"].items():
print(f"- {reason} (occurred {count} times)")
224 changes: 224 additions & 0 deletions benchmarks/compare_tool_calls.py
@@ -0,0 +1,224 @@
import json
import re

def compare_tool_calls(expected_calls, actual_calls):
"""
Compare expected tool calls with actual tool calls, allowing for flexible matching.

Args:
expected_calls (list): List of expected tool call dictionaries. Each call can include:
- optional_tool_input_keys: List of keys that are optional in the input
- optional: Boolean indicating if this tool call is optional
- input_patterns: Dict of keys whose values should be matched as regexps
- any_order: Boolean indicating if the order of this call is flexible
- allow_additional_keys: Boolean indicating if additional keys are allowed in actual input
actual_calls (list): List of actual tool call dictionaries

Returns:
tuple: (bool, str) - (success, reason)
"""
# Track which actual calls have been matched
remaining_actual_calls = actual_calls.copy()

# First handle required calls, then optional ones
required_calls = [call for call in expected_calls if not call.get('optional', False)]
optional_calls = [call for call in expected_calls if call.get('optional', False)]

if len(required_calls) > len(actual_calls):
return False, f"Not enough actual calls to match required calls. Need {len(required_calls)}, got {len(actual_calls)}"

    for expected in required_calls + optional_calls:
        is_optional = expected.get('optional', False)
        any_order = expected.get('any_order', False)  # accepted but currently unused: matching already scans all remaining actual calls
allow_additional = expected.get('allow_additional_keys', False)
input_patterns = expected.get('input_patterns', {})
optional_keys = expected.get('optional_tool_input_keys', [])

# Skip optional calls if no more actual calls available
if is_optional and not any(call is not None for call in remaining_actual_calls):
continue

match_found = False
for j, actual in enumerate(remaining_actual_calls):
if actual is None:
continue

match_result = _match_tool_call(
expected,
actual,
optional_keys=optional_keys,
input_patterns=input_patterns,
allow_additional=allow_additional
)

if match_result:
match_found = True
remaining_actual_calls[j] = None
break

        if not match_found and not is_optional:
            # Include the function name in the error message for better debugging
            function_name = expected.get('tool_name', 'unknown')
            return False, f"No matching actual call found for required call to function '{function_name}'"

    # Any actual calls left unmatched are tolerated; only the expected calls must be satisfied
    return True, "All tool calls match"

def _match_tool_call(expected, actual, optional_keys=None, input_patterns=None, allow_additional=False):
"""Helper function to match a single tool call"""

optional_keys = optional_keys or []
input_patterns = input_patterns or {}

# Check tool name
if expected['tool_name'] != actual['tool_name']:
return False

# Check status
if expected['tool_output']['status'] != actual['tool_output']['status']:
return False

expected_input = json.loads(expected['tool_input']) if isinstance(expected['tool_input'], str) else expected['tool_input']
actual_input = json.loads(actual['tool_input']) if isinstance(actual['tool_input'], str) else actual['tool_input']

# Check if non-optional keys exist
required_keys = set(expected_input.keys()) - set(optional_keys)
missing_keys = [key for key in required_keys if key not in actual_input]
if missing_keys:
return False

# Check if there are unexpected keys (when not allowed)
unexpected_keys = set(actual_input.keys()) - set(optional_keys) - set(expected_input.keys())
if not allow_additional and unexpected_keys:
return False

# Check each expected key
for key, expected_value in expected_input.items():
# Skip optional keys not in actual input
if key in optional_keys and key not in actual_input:
continue

if key not in actual_input:
return False

actual_value = actual_input[key]

# Check if there are any deep patterns for this key
deep_patterns = {k: v for k, v in input_patterns.items() if k.startswith(f"{key}.") or k.startswith(f"{key}[")}

# If we have deep patterns for this key, use deep comparison
if deep_patterns:
pattern_dict = {}
for pattern_key, pattern_value in deep_patterns.items():
# Handle both dot notation (obj.prop) and array notation (obj[0])
if pattern_key.startswith(f"{key}."):
# For dot notation: "associations.types.associationTypeId"
path = pattern_key.split(".", 1)[1] # Remove the root key part
elif pattern_key.startswith(f"{key}["):
# For array notation: "associations[0]"
path = pattern_key[len(key):] # Remove the root key part
else:
# Should not happen due to the filter above
continue

pattern_dict[path] = pattern_value

if not _deep_compare(expected_value, actual_value, pattern_dict):
return False

continue

# Simple pattern matching for string values
if key in input_patterns:
pattern = input_patterns[key]

if pattern is None:
# Skip pattern matching if pattern is None
continue

if not isinstance(actual_value, str):
return False

match = re.match(pattern, actual_value)

if not match:
return False
# Regular value comparison
elif expected_value != actual_value:
return False

return True

def _deep_compare(expected, actual, pattern_dict, current_path=""):
"""
Recursively compare objects with support for regex patterns at specific paths.

Args:
expected: The expected value (can be dict, list, or primitive)
actual: The actual value to compare against
pattern_dict: Dictionary mapping paths to regex patterns
current_path: Current path in the object hierarchy

Returns:
bool: True if objects match, False otherwise
"""
# If types don't match, fail immediately (except for numeric types)
if type(expected) != type(actual) and not (
isinstance(expected, (int, float)) and isinstance(actual, (int, float))
):
return False

# Handle dictionaries
if isinstance(expected, dict):
# Check all expected keys exist in actual
for key in expected:
if key not in actual:
return False

# Recursively check each key
for key in expected:
new_path = f"{current_path}.{key}" if current_path else key
if not _deep_compare(expected[key], actual[key], pattern_dict, new_path):
return False

return True

# Handle lists
elif isinstance(expected, list):
if len(expected) != len(actual):
return False

# Recursively check each item
for i, (exp_item, act_item) in enumerate(zip(expected, actual)):
new_path = f"{current_path}[{i}]"
if not _deep_compare(exp_item, act_item, pattern_dict, new_path):
return False

return True

# Handle primitive values
else:
# Check if we have a pattern for this path
if current_path in pattern_dict:
pattern = pattern_dict[current_path]

# Skip if pattern is None
if pattern is None:
return True

# For regex pattern matching, convert actual to string if needed
actual_str = str(actual)
match = re.match(pattern, actual_str)

if not match:
return False

return True
else:
# Regular equality check
if expected != actual:
return False

return True
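
As a usage sketch, a hypothetical expected/actual pair for compare_tool_calls; the tool name and payloads are illustrative and not taken from the benchmark data in this PR, but the dictionary shape (tool_name, tool_input, tool_output.status) matches what _match_tool_call reads, and the import path assumes benchmarks/ is importable as a package:

from benchmarks.compare_tool_calls import compare_tool_calls

expected = [{
    "tool_name": "create_company",                      # hypothetical tool name
    "tool_input": {"name": "ACME Ltd", "domain": "acme.com"},
    "tool_output": {"status": "success"},
    "input_patterns": {"domain": r"^acme\.(com|co)$"},  # regex instead of exact match
    "optional_tool_input_keys": ["domain"],             # the agent may omit this key
    "allow_additional_keys": True,                      # extra input keys are tolerated
}]

actual = [{
    "tool_name": "create_company",
    "tool_input": '{"name": "ACME Ltd", "domain": "acme.com", "industry": "Testing"}',
    "tool_output": {"status": "success"},
}]

ok, reason = compare_tool_calls(expected, actual)
print(ok, reason)  # True, "All tool calls match"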