Commit caa62c4

Merge branch 'main' into fix/handle-circular-refs

2 parents: 6b6a860 + 553e376
File tree: 3 files changed (+124, -15 lines)

src/google/adk/evaluation/custom_metric_evaluator.py (10 additions, 13 deletions)

@@ -24,7 +24,6 @@
 from .eval_case import ConversationScenario
 from .eval_case import Invocation
 from .eval_metrics import EvalMetric
-from .eval_metrics import EvalStatus
 from .evaluator import EvaluationResult
 from .evaluator import Evaluator

@@ -44,12 +43,6 @@ def _get_metric_function(
 ) from e


-def _get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
-  if score is None:
-    return EvalStatus.NOT_EVALUATED
-  return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
-
-
 class _CustomMetricEvaluator(Evaluator):
   """Evaluator for custom metrics."""

@@ -64,16 +57,20 @@ async def evaluate_invocations(
       expected_invocations: Optional[list[Invocation]],
       conversation_scenario: Optional[ConversationScenario] = None,
   ) -> EvaluationResult:
+    eval_metric = self._eval_metric.model_copy(deep=True)
+    eval_metric.threshold = None
     if inspect.iscoroutinefunction(self._metric_function):
      eval_result = await self._metric_function(
-          actual_invocations, expected_invocations, conversation_scenario
+          eval_metric,
+          actual_invocations,
+          expected_invocations,
+          conversation_scenario,
      )
    else:
      eval_result = self._metric_function(
-          actual_invocations, expected_invocations, conversation_scenario
+          eval_metric,
+          actual_invocations,
+          expected_invocations,
+          conversation_scenario,
      )
-
-    eval_result.overall_eval_status = _get_eval_status(
-        eval_result.overall_score, self._eval_metric.threshold
-    )
     return eval_result
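
With this change, the evaluator passes a deep copy of the metric (with `threshold` cleared) as the first argument to the custom function, and it no longer derives `overall_eval_status` from a threshold; that responsibility moves to the metric function itself. Below is a minimal sketch of a custom metric function under the new signature. The function name, the scoring logic, the comparison via `Invocation.final_response`, and the 0.8 cutoff are illustrative assumptions, not part of this commit:

# A sketch of a custom metric function under the new calling convention.
# The name, scoring logic, and 0.8 cutoff are illustrative assumptions.
from typing import Optional

from google.adk.evaluation.eval_case import ConversationScenario
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.eval_metrics import EvalStatus
from google.adk.evaluation.evaluator import EvaluationResult


def exact_match_metric(
    eval_metric: EvalMetric,  # a deep copy; eval_metric.threshold is None
    actual_invocations: list[Invocation],
    expected_invocations: Optional[list[Invocation]],
    conversation_scenario: Optional[ConversationScenario],
) -> EvaluationResult:
  """Scores the fraction of invocations whose final responses match."""
  if not expected_invocations:
    return EvaluationResult()  # no score; status stays NOT_EVALUATED
  matches = sum(
      actual.final_response == expected.final_response
      for actual, expected in zip(actual_invocations, expected_invocations)
  )
  score = matches / len(expected_invocations)
  # The evaluator no longer sets overall_eval_status from a threshold,
  # so the metric function decides pass/fail itself (assumed cutoff: 0.8).
  return EvaluationResult(
      overall_score=score,
      overall_eval_status=(
          EvalStatus.PASSED if score >= 0.8 else EvalStatus.FAILED
      ),
  )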

src/google/adk/evaluation/eval_metrics.py (4 additions, 2 deletions)

@@ -258,9 +258,11 @@ class EvalMetric(EvalBaseModel):
       description="The name of the metric.",
   )

-  threshold: float = Field(
+  threshold: Optional[float] = Field(
+      default=None,
      description=(
-          "A threshold value. Each metric decides how to interpret this"
+          "This field will be deprecated soon. Please use `criterion` instead."
+          " A threshold value. Each metric decides how to interpret this"
          " threshold."
      ),
  )
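
Since `threshold` now defaults to `None`, a metric validates without one, which is exactly how the new tests below construct `EvalMetric`. A short sketch (the metric name is arbitrary):

# threshold is now Optional with default None, so both forms validate.
from google.adk.evaluation.eval_metrics import EvalMetric

metric = EvalMetric(metric_name="my_metric")  # new style: no threshold
assert metric.threshold is None

legacy = EvalMetric(metric_name="my_metric", threshold=0.8)  # still accepted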
test_custom_metric_evaluator.py (new file: 110 additions, 0 deletions)

@@ -0,0 +1,110 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
+from google.adk.evaluation.custom_metric_evaluator import _get_metric_function
+from google.adk.evaluation.eval_case import ConversationScenario
+from google.adk.evaluation.eval_case import Invocation
+from google.adk.evaluation.eval_metrics import EvalMetric
+from google.adk.evaluation.evaluator import EvaluationResult
+import pytest
+
+
+def my_sync_metric_function(
+    eval_metric: EvalMetric,
+    actual_invocations: list[Invocation],
+    expected_invocations: list[Invocation] | None,
+    conversation_scenario: ConversationScenario | None,
+) -> EvaluationResult:
+  """Sync metric function for testing."""
+  return EvaluationResult(overall_score=1.0)
+
+
+async def my_async_metric_function(
+    eval_metric: EvalMetric,
+    actual_invocations: list[Invocation],
+    expected_invocations: list[Invocation] | None,
+    conversation_scenario: ConversationScenario | None,
+) -> EvaluationResult:
+  """Async metric function for testing."""
+  return EvaluationResult(overall_score=0.5)
+
+
+@mock.patch("importlib.import_module")
+def test_get_metric_function_success(mock_import_module):
+  """Tests that _get_metric_function successfully returns a function."""
+  mock_module = mock.MagicMock()
+  mock_module.my_sync_metric_function = my_sync_metric_function
+  mock_import_module.return_value = mock_module
+  func = _get_metric_function(
+      "test_custom_metric_evaluator.my_sync_metric_function"
+  )
+  assert func == my_sync_metric_function
+
+
+@mock.patch("importlib.import_module", side_effect=ImportError)
+def test_get_metric_function_module_not_found(mock_import_module):
+  """Tests that _get_metric_function raises ImportError for non-existent module."""
+  with pytest.raises(ImportError):
+    _get_metric_function("non_existent_module.my_sync_metric_function")
+
+
+@mock.patch("importlib.import_module")
+def test_get_metric_function_function_not_found(mock_import_module):
+  """Tests that _get_metric_function raises ImportError for non-existent function."""
+  mock_import_module.return_value = object()
+  with pytest.raises(ImportError):
+    _get_metric_function(
+        "google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.non_existent_function"
+    )
+
+
+def test_get_metric_function_malformed_path():
+  """Tests that _get_metric_function raises ImportError for malformed path."""
+  with pytest.raises(ImportError):
+    _get_metric_function("malformed_path")
+
+
+@mock.patch(
+    "google.adk.evaluation.custom_metric_evaluator._get_metric_function",
+    return_value=my_sync_metric_function,
+)
+@pytest.mark.asyncio
+async def test_custom_metric_evaluator_sync_function(mock_get_metric_function):
+  """Tests that _CustomMetricEvaluator works with a sync metric function."""
+  eval_metric = EvalMetric(metric_name="sync_metric")
+  evaluator = _CustomMetricEvaluator(
+      eval_metric=eval_metric,
+      custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_sync_metric_function",
+  )
+  result = await evaluator.evaluate_invocations([], None)
+  assert result.overall_score == 1.0
+
+
+@mock.patch(
+    "google.adk.evaluation.custom_metric_evaluator._get_metric_function",
+    return_value=my_async_metric_function,
+)
+@pytest.mark.asyncio
+async def test_custom_metric_evaluator_async_function(mock_get_metric_function):
+  """Tests that _CustomMetricEvaluator works with an async metric function."""
+  eval_metric = EvalMetric(metric_name="async_metric")
+  evaluator = _CustomMetricEvaluator(
+      eval_metric=eval_metric,
+      custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_async_metric_function",
+  )
+  result = await evaluator.evaluate_invocations([], None)
+  assert result.overall_score == 0.5
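
For context on the dotted-path convention these tests exercise (`module.path.function_name`, resolved via `importlib`, with every failure surfaced as `ImportError`), here is a standalone sketch of that resolution pattern. It mirrors the behavior the tests assert for `_get_metric_function`, but it is not the ADK implementation itself:

import importlib


def resolve_dotted_path(path: str):
  """Resolves 'pkg.module.func' to a callable, raising ImportError on failure.

  A sketch of the behavior the tests above assert for _get_metric_function;
  the real logic lives in custom_metric_evaluator.
  """
  module_path, _, function_name = path.rpartition(".")
  if not module_path:
    # e.g. "malformed_path" has no module component.
    raise ImportError(f"Malformed function path: {path!r}")
  module = importlib.import_module(module_path)  # ImportError if missing
  try:
    return getattr(module, function_name)
  except AttributeError as e:
    raise ImportError(
        f"{function_name!r} not found in module {module_path!r}"
    ) from e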
