
Commit 395806c

add initial version of revamped coded evaluators
1 parent 84ae313 commit 395806c

20 files changed

Lines changed: 2406 additions & 12 deletions

src/uipath/eval/_helpers/helpers.py

Lines changed: 357 additions & 0 deletions
@@ -1,11 +1,24 @@
 import json
 import os
+from collections.abc import Mapping
+from datetime import datetime
+from typing import Any, Sequence

 import click
+from opentelemetry.sdk.trace import ReadableSpan

 from uipath._cli._utils._console import ConsoleLogger
 from uipath._utils.constants import UIPATH_CONFIG_FILE

+COMPARATOR_MAPPINGS = {
+    ">": "gt",
+    "<": "lt",
+    ">=": "ge",
+    "<=": "le",
+    "=": "eq",
+    "!=": "ne",
+}
+

 def auto_discover_entrypoint() -> str:
     """Auto-discover entrypoint from config file.
@@ -45,3 +58,347 @@ def auto_discover_entrypoint() -> str:
         f"Auto-discovered agent entrypoint: {click.style(entrypoint, fg='cyan')}"
     )
     return entrypoint
+
+
+def extract_tool_calls_names(spans: Sequence[ReadableSpan]) -> list[str]:
+    """Extract the tool call names from execution spans IN ORDER.
+
+    Args:
+        spans: List of ReadableSpan objects from agent execution.
+
+    Returns:
+        List of tool names in the order they were called.
+    """
+    tool_calls_names = []
+
+    for span in spans:
+        # Check for tool.name attribute first
+        if span.attributes and (tool_name := span.attributes.get("tool.name")):
+            tool_calls_names.append(tool_name)
+
+    return tool_calls_names
+
+
+def extract_tool_calls(spans: Sequence[ReadableSpan]) -> list[dict[str, Any]]:
+    """Extract the tool calls from execution spans with their arguments.
+
+    Args:
+        spans: List of ReadableSpan objects from agent execution.
+
+    Returns:
+        List of tool calls with their arguments.
+    """
+    tool_calls = []
+
+    for span in spans:
+        if span.attributes and (tool_name := span.attributes.get("tool.name")):
+            try:
+                input_value = span.attributes.get("input.value", "{}")
+                # Ensure input_value is a string before parsing
+                if isinstance(input_value, str):
+                    arguments = json.loads(input_value.replace("'", '"'))
+                else:
+                    arguments = {}
+                tool_calls.append({"name": tool_name, "args": arguments})
+            except json.JSONDecodeError:
+                # Handle case where input.value is not valid JSON
+                tool_calls.append({"name": tool_name, "args": {}})
+
+    return tool_calls
+
+
+def extract_tool_calls_outputs(spans: Sequence[ReadableSpan]) -> list[dict[str, Any]]:
+    """Extract the outputs of the tool calls from execution spans."""
+    tool_calls_outputs = []
+    for span in spans:
+        if span.attributes and (tool_name := span.attributes.get("tool.name")):
+            tool_calls_outputs.append(
+                {"name": tool_name, "output": span.attributes.get("output.value", {})}
+            )
+    return tool_calls_outputs
+
+
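A quick usage sketch of the three extractors (illustrative only, not part of the diff). The span below is a SimpleNamespace stand-in for an OTEL ReadableSpan (at runtime the helpers only read the attributes mapping), and the tool name and values are invented:

from types import SimpleNamespace

# Stand-in span carrying the attributes the helpers read.
fake_span = SimpleNamespace(
    attributes={
        "tool.name": "search_invoices",               # hypothetical tool
        "input.value": "{'query': 'Q3 report'}",      # single-quoted JSON, as some tracers emit
        "output.value": '{"content": "2 invoices found"}',
    }
)

print(extract_tool_calls_names([fake_span]))    # ['search_invoices']
print(extract_tool_calls([fake_span]))          # [{'name': 'search_invoices', 'args': {'query': 'Q3 report'}}]
print(extract_tool_calls_outputs([fake_span]))  # [{'name': 'search_invoices', 'output': '{"content": "2 invoices found"}'}]
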
+def tool_calls_order_score(
+    actual_tool_calls_names: Sequence[str],
+    expected_tool_calls_names: Sequence[str],
+    strict: bool = False,
+) -> tuple[float, str]:
+    """Calculate a score based on the LCS applied to the order of the tool calls.
+
+    It calculates the longest common subsequence between the actual tool calls
+    and the expected tool calls and returns the ratio of the LCS length to the number of
+    expected calls.
+
+    Args:
+        actual_tool_calls_names: List of tool names in the actual order
+        expected_tool_calls_names: List of tool names in the expected order
+        strict: If True, the function will return 0 if the actual calls do not match the expected calls
+
+    Returns:
+        tuple[float, str]: Ratio of the LCS length to the number of expected calls, and the LCS string
+    """
+    justification_template = f"Expected tool calls: {expected_tool_calls_names}\nActual tool calls: {actual_tool_calls_names}"
+    if not strict:
+        justification_template += "\nLongest common subsequence: {lcs}"
+    if expected_tool_calls_names == actual_tool_calls_names:
+        return 1.0, justification_template.format(lcs=actual_tool_calls_names)
+    elif (
+        not expected_tool_calls_names
+        or not actual_tool_calls_names
+        or strict
+        and actual_tool_calls_names != expected_tool_calls_names
+    ):
+        return 0.0, justification_template.format(lcs="")
+
+    # Calculate LCS with full DP table for efficient reconstruction
+    m, n = len(actual_tool_calls_names), len(expected_tool_calls_names)
+    dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+    # Build DP table - O(m*n)
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            if actual_tool_calls_names[i - 1] == expected_tool_calls_names[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1] + 1
+            else:
+                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+
+    # Reconstruct LCS - O(m+n)
+    lcs = []
+    i, j = m, n
+    while i > 0 and j > 0:
+        if actual_tool_calls_names[i - 1] == expected_tool_calls_names[j - 1]:
+            lcs.append(actual_tool_calls_names[i - 1])
+            i -= 1
+            j -= 1
+        elif dp[i - 1][j] > dp[i][j - 1]:
+            i -= 1
+        else:
+            j -= 1
+
+    lcs.reverse()  # Reverse to get correct order
+    lcs_length = len(lcs)
+    return lcs_length / n, justification_template.format(lcs=" ".join(lcs))
+
+
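A quick usage sketch for the order score (illustrative only, not part of the diff; tool names invented). An expected order of create, close, notify versus an actual order of create, notify, close shares a longest common subsequence of length 2:

# Non-strict: score is LCS length / number of expected calls.
score, justification = tool_calls_order_score(
    actual_tool_calls_names=["create_ticket", "notify_user", "close_ticket"],
    expected_tool_calls_names=["create_ticket", "close_ticket", "notify_user"],
)
print(score)  # 0.666... (LCS length 2 over 3 expected calls)

# Strict: any deviation from the exact expected order scores 0.0.
score, _ = tool_calls_order_score(
    ["create_ticket", "notify_user", "close_ticket"],
    ["create_ticket", "close_ticket", "notify_user"],
    strict=True,
)
print(score)  # 0.0
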
+def tool_calls_count_score(
+    actual_tool_calls_count: Mapping[str, int],
+    expected_tool_calls_count: Mapping[str, tuple[str, int]],
+    strict: bool = False,
+) -> tuple[float, str]:
+    """Check whether each expected tool was called a number of times that satisfies the expected comparator and count.
+
+    It does not check the order of the tool calls!
+    """
+    if not expected_tool_calls_count and not actual_tool_calls_count:
+        return 1.0, "Both expected and actual tool calls are empty"
+    elif not expected_tool_calls_count or not actual_tool_calls_count:
+        return 0.0, "Either expected or actual tool calls are empty"
+
+    score = 0.0
+    justifications = []
+    for tool_name, (
+        expected_comparator,
+        expected_count,
+    ) in expected_tool_calls_count.items():
+        actual_count = actual_tool_calls_count.get(tool_name, 0.0)
+        comparator = f"__{COMPARATOR_MAPPINGS[expected_comparator]}__"
+        to_add = float(getattr(actual_count, comparator)(expected_count))
+        justifications.append(
+            f"{tool_name}: Actual count: {actual_count}, Expected count: {expected_count}, Score: {to_add}"
+        )
+        if strict and to_add == 0.0:
+            return 0.0, justifications[-1]
+        score += to_add
+    return score / len(expected_tool_calls_count), "\n".join(justifications)
+
+
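A quick usage sketch for the count score (illustrative only, not part of the diff; tool names invented). Expectations are (comparator, count) pairs using the operators from COMPARATOR_MAPPINGS:

score, justification = tool_calls_count_score(
    actual_tool_calls_count={"create_ticket": 1, "notify_user": 3},
    expected_tool_calls_count={
        "create_ticket": ("=", 1),   # 1 == 1 -> contributes 1.0
        "notify_user": ("<=", 2),    # 3 <= 2 -> contributes 0.0
    },
)
print(score)          # 0.5 (average over the expected tools)
print(justification)  # one justification line per expected tool
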
+def tool_args_score(
+    actual_tool_calls: list[dict[str, Any]],
+    expected_tool_calls: list[dict[str, Any]],
+    strict: bool = False,
+    subset: bool = False,
+) -> float:
+    """Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.
+
+    This function does not check the order of the tool calls!
+
+    Arguments:
+        actual_tool_calls (list[Dict[str, Any]]): List of actual tool calls in the format of {"name": str, "args": Dict[str, Any]}
+        expected_tool_calls (list[Dict[str, Any]]): List of expected tool calls in the format of {"name": str, "args": Dict[str, Any]}
+        strict (bool): If True, the function will return 0 if not all expected tool calls are matched
+        subset (bool): If True, the function will check if the expected args are a subset of the actual args
+
+    Returns:
+        float: Score based on the number of matches
+    """
+    cnt = 0
+    visited: set[int] = set()
+
+    for expected_tool_call in expected_tool_calls:
+        for idx, call in enumerate(actual_tool_calls):
+            if (
+                call.get("name") == expected_tool_call.get("name")
+                and idx not in visited
+            ):
+                # Check arguments based on mode
+                if subset:
+                    # Subset mode: safely check if all expected args exist and match
+                    args_check = (  # noqa: E731
+                        lambda k, v: k in call.get("args", {})  # noqa: B023
+                        and call.get("args", {})[k] == v  # noqa: B023
+                    )
+                    validator_check = lambda k, validator: k not in call.get(  # noqa: E731, B023
+                        "args", {}
+                    ) or validator(call.get("args", {})[k])  # noqa: B023
+                else:
+                    # Exact mode: direct access (may raise KeyError)
+                    args_check = lambda k, v: call.get("args", {})[k] == v  # noqa: E731, B023
+                    validator_check = lambda k, validator: validator(  # noqa: E731
+                        call.get("args", {})[k]  # noqa: B023
+                    )
+
+                try:
+                    args_match = all(
+                        args_check(k, v)
+                        for k, v in expected_tool_call.get("args", {}).items()
+                    )
+                    validators_match = True
+                    if expected_tool_call.get("args_validators", {}):
+                        validators_match = all(
+                            validator_check(k, validator)
+                            for k, validator in expected_tool_call.get(
+                                "args_validators", {}
+                            ).items()
+                        )
+                except KeyError:
+                    # Only possible in exact mode when key is missing
+                    args_match = False
+                    validators_match = False
+                if args_match and validators_match:
+                    cnt += 1
+                    visited.add(idx)
+                    break
+
+    return (
+        cnt / len(expected_tool_calls)
+        if not strict
+        else float(cnt == len(expected_tool_calls))
+    )
+
+
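A quick usage sketch for the args score (illustrative only, not part of the diff; tool and argument names invented). The two modes differ when a validator targets an argument the agent never passed:

actual = [
    {"name": "create_ticket", "args": {"priority": "high", "assignee": "bot"}},
]
expected = [
    {
        "name": "create_ticket",
        "args": {"priority": "high"},                        # must be present with this value
        "args_validators": {"due_date": lambda v: bool(v)},  # only applied if "due_date" was passed
    },
]

print(tool_args_score(actual, expected, subset=True))   # 1.0: args match, missing "due_date" is tolerated
print(tool_args_score(actual, expected, subset=False))  # 0.0: exact mode hits KeyError on "due_date", so no match
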
+def tool_output_score(
+    actual_tool_calls_outputs: list[dict[str, Any]],
+    expected_tool_calls_outputs: list[dict[str, Any]],
+    strict: bool = False,
+) -> float:
+    """Check if the expected tool call outputs match the outputs of the actual tool calls.
+
+    This function does not check the order of the tool calls!
+    """
+    if not expected_tool_calls_outputs and not actual_tool_calls_outputs:
+        return 1.0
+    elif (
+        not expected_tool_calls_outputs
+        or not actual_tool_calls_outputs
+        or strict
+        and actual_tool_calls_outputs != expected_tool_calls_outputs
+    ):
+        return 0.0
+
+    cnt = 0.0
+    for expected_tool_call_output in expected_tool_calls_outputs:
+        for actual_tool_call_output in actual_tool_calls_outputs:
+            if actual_tool_call_output.get("name") == expected_tool_call_output.get(
+                "name"
+            ):
+                if json.loads(actual_tool_call_output.get("output", "{}")).get(
+                    "content"
+                ) == expected_tool_call_output.get("output"):
+                    cnt += 1.0
+                elif strict:
+                    return 0.0
+    return (
+        cnt / len(expected_tool_calls_outputs)
+        if not strict
+        else float(cnt == len(expected_tool_calls_outputs))
+    )
+
+
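A quick usage sketch for the output score (illustrative only, not part of the diff; values invented). The actual output is expected to be a JSON string whose "content" field is compared to the expected value:

actual_outputs = [
    {"name": "search_invoices", "output": '{"content": "2 invoices found"}'},
]
expected_outputs = [
    {"name": "search_invoices", "output": "2 invoices found"},
]

print(tool_output_score(actual_outputs, expected_outputs))  # 1.0
# With strict=True the same call returns 0.0, because the raw lists are compared
# for equality before any per-tool matching happens.
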
+def trace_to_str(agent_trace: Sequence[ReadableSpan]) -> str:
+    """Convert OTEL spans to a platform-style agent run history string.
+
+    Creates a similar structure to LangChain message processing but using OTEL spans.
+    Only processes tool spans (spans with 'tool.name' attribute).
+
+    Args:
+        agent_trace: List of ReadableSpan objects from the agent execution
+
+    Returns:
+        String representation of the agent run history in platform format
+    """
+    platform_history = []
+    seen_tool_calls = set()
+
+    for span in agent_trace:
+        if span.attributes and (tool_name := span.attributes.get("tool.name")):
+            # Get span timing information
+            start_time = span.start_time
+            end_time = span.end_time
+
+            # Convert nanoseconds to datetime if needed
+            if isinstance(start_time, int):
+                start_timestamp = datetime.fromtimestamp(start_time / 1e9)
+            else:
+                start_timestamp = start_time
+
+            if isinstance(end_time, int):
+                end_timestamp = datetime.fromtimestamp(end_time / 1e9)
+            else:
+                end_timestamp = end_time
+
+            timestamp_str = (
+                start_timestamp.strftime("%Y-%m-%d %H:%M:%S") if start_timestamp else ""
+            )
+
+            # Get tool call information
+            tool_args = span.attributes.get("input.value", {})
+            tool_result = span.attributes.get("output.value", "{}")
+            # Attempt to extract only the content of the tool result if it is a string
+            if isinstance(tool_result, str):
+                try:
+                    tool_result = json.loads(tool_result.replace("'", '"'))["content"]
+                except (json.JSONDecodeError, KeyError):
+                    tool_result = tool_result
+
+            span_id = (
+                span.context.span_id
+                if span.context
+                else str(hash(f"{tool_name}_{timestamp_str}"))
+            )
+
+            # De-duplicate tool calls based on span ID
+            if span_id in seen_tool_calls:
+                continue
+            seen_tool_calls.add(span_id)
+
+            # Add tool selection (equivalent to AIMessage with tool_calls)
+            platform_history.append(f"[{timestamp_str}] LLM Response:")
+            platform_history.append(" Agent Selected 1 Tool(s):")
+            platform_history.append("")
+            platform_history.append(f" Tool: {tool_name}")
+            platform_history.append(f" Arguments: {str(tool_args)}")
+            platform_history.append("")
+
+            # Add tool response (equivalent to ToolMessage)
+            end_timestamp_str = (
+                end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
+                if end_timestamp
+                else timestamp_str
+            )
+            platform_history.append(
+                f"[{end_timestamp_str}] Tool Call Response - {tool_name}:"
+            )
+            platform_history.append(f"{str(tool_result).strip()}")
+            platform_history.append("")
+
+    return "\n".join(platform_history)
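A quick usage sketch for trace_to_str (illustrative only, not part of the diff). As above, a SimpleNamespace stands in for a ReadableSpan; timestamps are epoch nanoseconds and all values are invented, with the output shown approximately:

from types import SimpleNamespace

fake_span = SimpleNamespace(
    attributes={
        "tool.name": "search_invoices",
        "input.value": "{'query': 'Q3 report'}",
        "output.value": '{"content": "2 invoices found"}',
    },
    start_time=1_700_000_000_000_000_000,  # nanoseconds, converted via datetime.fromtimestamp
    end_time=1_700_000_001_000_000_000,
    context=None,  # falls back to a hash-based de-duplication key
)

print(trace_to_str([fake_span]))
# [2023-11-...] LLM Response:
#  Agent Selected 1 Tool(s):
#
#  Tool: search_invoices
#  Arguments: {'query': 'Q3 report'}
#
# [2023-11-...] Tool Call Response - search_invoices:
# 2 invoices found
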
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""UiPath evaluator implementations for agent performance evaluation."""
+
+from .base_evaluator import BaseEvaluator
+from .exact_match_evaluator import ExactMatchEvaluator
+from .json_similarity_evaluator import JsonSimilarityEvaluator
+from .llm_as_judge_evaluator import LLMJudgeEvaluator
+from .llm_judge_trajectory_evaluator import LLMJudgeTrajectoryEvaluator
+
+__all__ = [
+    "BaseEvaluator",
+    "ExactMatchEvaluator",
+    "JsonSimilarityEvaluator",
+    "LLMJudgeEvaluator",
+    "LLMJudgeTrajectoryEvaluator",
+]
