diff --git a/src/strands_evals/case.py b/src/strands_evals/case.py
index 34dd82b6..aa438604 100644
--- a/src/strands_evals/case.py
+++ b/src/strands_evals/case.py
@@ -3,7 +3,7 @@
 from pydantic import BaseModel, Field
 from typing_extensions import Any, Generic
 
-from .types.evaluation import InputT, Interaction, OutputT
+from .types.evaluation import EnvironmentState, InputT, Interaction, OutputT
 
 
 class Case(BaseModel, Generic[InputT, OutputT]):
@@ -47,4 +47,5 @@ class Case(BaseModel, Generic[InputT, OutputT]):
     expected_output: OutputT | None = None
     expected_trajectory: list[Any] | None = None
     expected_interactions: list[Interaction] | None = None
+    expected_environment_state: list[EnvironmentState] | None = None
     metadata: dict[str, Any] | None = None
diff --git a/src/strands_evals/evaluators/__init__.py b/src/strands_evals/evaluators/__init__.py
index f4479fc8..73411ad5 100644
--- a/src/strands_evals/evaluators/__init__.py
+++ b/src/strands_evals/evaluators/__init__.py
@@ -1,6 +1,6 @@
 from .coherence_evaluator import CoherenceEvaluator
 from .conciseness_evaluator import ConcisenessEvaluator
-from .deterministic import Contains, Equals, StartsWith, ToolCalled
+from .deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled
 from .evaluator import Evaluator
 from .faithfulness_evaluator import FaithfulnessEvaluator
 from .goal_success_rate_evaluator import GoalSuccessRateEvaluator
@@ -30,5 +30,6 @@
     "Contains",
     "Equals",
     "StartsWith",
+    "StateEquals",
     "ToolCalled",
 ]
diff --git a/src/strands_evals/evaluators/deterministic/__init__.py b/src/strands_evals/evaluators/deterministic/__init__.py
index 7a92a91e..66cba320 100644
--- a/src/strands_evals/evaluators/deterministic/__init__.py
+++ b/src/strands_evals/evaluators/deterministic/__init__.py
@@ -1,3 +1,4 @@
+from .environment_state import StateEquals
 from .output import Contains, Equals, StartsWith
 from .trajectory import ToolCalled
 
@@ -5,5 +6,6 @@
     "Contains",
     "Equals",
     "StartsWith",
+    "StateEquals",
     "ToolCalled",
 ]
diff --git a/src/strands_evals/evaluators/deterministic/environment_state.py b/src/strands_evals/evaluators/deterministic/environment_state.py
new file mode 100644
index 00000000..23673906
--- /dev/null
+++ b/src/strands_evals/evaluators/deterministic/environment_state.py
@@ -0,0 +1,67 @@
+from typing_extensions import Any
+
+from ...types.evaluation import EnvironmentState, EvaluationData, EvaluationOutput, InputT, OutputT
+from ..evaluator import Evaluator
+
+
+def _find_state_by_name(states: list[EnvironmentState], name: str) -> EnvironmentState | None:
+    """Find an EnvironmentState by name in a list of states."""
+    for state in states:
+        if state.name == name:
+            return state
+    return None
+
+
+class StateEquals(Evaluator[InputT, OutputT]):
+    """Checks if a named environment state matches an expected value."""
+
+    def __init__(self, name: str, value: Any | None = None):
+        super().__init__()
+        self.name = name
+        self.value = value
+
+    def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
+        if not evaluation_case.actual_environment_state:
+            return [
+                EvaluationOutput(
+                    score=0.0,
+                    test_pass=False,
+                    reason=f"state '{self.name}' not found: actual_environment_state is empty or None",
+                )
+            ]
+
+        actual_state = _find_state_by_name(evaluation_case.actual_environment_state, self.name)
+        if actual_state is None:
+            return [
+                EvaluationOutput(
+                    score=0.0,
+                    test_pass=False,
+                    reason=f"state '{self.name}' not found in actual_environment_state",
+                )
+            ]
+
+        if self.value is not None:
+            expected = self.value
+        elif evaluation_case.expected_environment_state:
+            expected_state = _find_state_by_name(evaluation_case.expected_environment_state, self.name)
+            if expected_state is None:
+                raise ValueError(
+                    f"state '{self.name}' not found in expected_environment_state and no explicit value provided"
+                )
+            expected = expected_state.state
+        else:
+            raise ValueError(
+                f"no expected value for state '{self.name}': provide value param or expected_environment_state"
+            )
+
+        match = actual_state.state == expected
+        return [
+            EvaluationOutput(
+                score=1.0 if match else 0.0,
+                test_pass=match,
+                reason=f"state '{self.name}' {'matches' if match else 'does not match'} expected value",
+            )
+        ]
+
+    async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
+        return self.evaluate(evaluation_case)
diff --git a/src/strands_evals/evaluators/prompt_templates/case_prompt_template.py b/src/strands_evals/evaluators/prompt_templates/case_prompt_template.py
index d273bb26..62a23bda 100644
--- a/src/strands_evals/evaluators/prompt_templates/case_prompt_template.py
+++ b/src/strands_evals/evaluators/prompt_templates/case_prompt_template.py
@@ -7,6 +7,7 @@ def compose_test_prompt(
     include_inputs: bool,
     uses_trajectory: bool = False,
     trajectory_description: dict | None = None,
+    uses_environment_state: bool = False,
 ) -> str:
     """
     Compose the prompt for a test case evaluation.
@@ -17,19 +18,21 @@
         include_inputs: Whether to include the input in the prompt
         uses_trajectory: Whether this is a trajectory-based evaluation
         trajectory_description: A dictionary describing the type of trajectory expected for this evaluation.
+        uses_environment_state: Whether this is an environment-state-based evaluation
 
     Returns:
         str: The formatted evaluation prompt
 
     Raises:
-        Exception: If actual_output is missing for non-trajectory evaluations
+        Exception: If actual_output is missing for output-only evaluations
         Exception: If actual_trajectory is missing for trajectory evaluations
+        Exception: If actual_environment_state is missing for environment state evaluations
     """
     evaluation_prompt = "Evaluate this singular test case. THE FINAL SCORE MUST BE A DECIMAL BETWEEN 0.0 AND 1.0 (NOT 0 to 10 OR 0 to 100). \n"
\n" if include_inputs: evaluation_prompt += f"{evaluation_case.input}\n" - if uses_trajectory: # trajectory evaluations don't require actual_output + if uses_trajectory or uses_environment_state: # these evaluations don't require actual_output if evaluation_case.actual_output: evaluation_prompt += f"{evaluation_case.actual_output}\n" else: @@ -53,6 +56,18 @@ def compose_test_prompt( if trajectory_description: evaluation_prompt += f"{trajectory_description}\n" + if uses_environment_state: + if evaluation_case.actual_environment_state is None: + raise Exception("Please make sure the task function return a dictionary with the key 'environment_state'.") + evaluation_prompt += ( + f"{evaluation_case.actual_environment_state}\n" + ) + + if evaluation_case.expected_environment_state: + evaluation_prompt += ( + f"{evaluation_case.expected_environment_state}\n" + ) + evaluation_prompt += f"{rubric}" return evaluation_prompt diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index 874437ad..237175f7 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -17,7 +17,7 @@ from typing_extensions import Any, Generic from .case import Case -from .evaluators.deterministic import Contains, Equals, StartsWith, ToolCalled +from .evaluators.deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled from .evaluators.evaluator import Evaluator from .evaluators.interactions_evaluator import InteractionsEvaluator from .evaluators.output_evaluator import OutputEvaluator @@ -202,6 +202,7 @@ def _run_task( expected_output=case.expected_output, expected_trajectory=case.expected_trajectory, expected_interactions=case.expected_interactions, + expected_environment_state=case.expected_environment_state, metadata=case.metadata, ) task_output = task(case) @@ -209,6 +210,7 @@ def _run_task( evaluation_context.actual_output = task_output.get("output") evaluation_context.actual_trajectory = task_output.get("trajectory") evaluation_context.actual_interactions = task_output.get("interactions") + evaluation_context.actual_environment_state = task_output.get("environment_state") new_input = task_output.get("input", None) # allows the user to update the input in the task function if new_input is not None: evaluation_context.input = new_input @@ -238,6 +240,7 @@ async def _run_task_async( expected_output=case.expected_output, expected_trajectory=case.expected_trajectory, expected_interactions=case.expected_interactions, + expected_environment_state=case.expected_environment_state, metadata=case.metadata, ) @@ -252,6 +255,7 @@ async def _run_task_async( evaluation_context.actual_output = task_output.get("output") evaluation_context.actual_trajectory = task_output.get("trajectory") evaluation_context.actual_interactions = task_output.get("interactions") + evaluation_context.actual_environment_state = task_output.get("environment_state") # allows the user to update the input in the task function new_input = task_output.get("input", None) if new_input is not None: @@ -798,6 +802,7 @@ def from_dict(cls, data: dict, custom_evaluators: list[type[Evaluator]] | None = "Equals": Equals, "Contains": Contains, "StartsWith": StartsWith, + "StateEquals": StateEquals, "ToolCalled": ToolCalled, } all_evaluators: dict[str, type[Evaluator]] = { diff --git a/src/strands_evals/types/__init__.py b/src/strands_evals/types/__init__.py index f8c8938c..9c4d57b1 100644 --- a/src/strands_evals/types/__init__.py +++ b/src/strands_evals/types/__init__.py @@ -1,7 +1,8 @@ -from .evaluation 
+from .evaluation import EnvironmentState, EvaluationData, EvaluationOutput, InputT, Interaction, OutputT, TaskOutput
 from .simulation import ActorProfile, ActorResponse
 
 __all__ = [
+    "EnvironmentState",
     "Interaction",
     "TaskOutput",
     "EvaluationData",
diff --git a/src/strands_evals/types/evaluation.py b/src/strands_evals/types/evaluation.py
index 81889981..1c4d627a 100644
--- a/src/strands_evals/types/evaluation.py
+++ b/src/strands_evals/types/evaluation.py
@@ -33,6 +33,18 @@ class Interaction(TypedDict, total=False):
     messages: list
 
 
+class EnvironmentState(BaseModel):
+    """A named piece of environment state captured after task execution.
+
+    Attributes:
+        name: Identifier for this state (e.g., "test_results", "file_system")
+        state: The captured state data
+    """
+
+    name: str
+    state: Any
+
+
 class TaskOutput(TypedDict, total=False):
     """
     Structured output format for task functions that return complex results.
@@ -59,6 +71,7 @@ class TaskOutput(TypedDict, total=False):
     trajectory: Union[list[Any], Session, None]
     interactions: list[Interaction]
     input: Any
+    environment_state: list[EnvironmentState]
 
 
 class EvaluationData(BaseModel, Generic[InputT, OutputT]):
@@ -86,6 +99,8 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]):
     metadata: dict[str, Any] | None = None
     actual_interactions: list[Interaction] | None = None
     expected_interactions: list[Interaction] | None = None
+    actual_environment_state: list[EnvironmentState] | None = None
+    expected_environment_state: list[EnvironmentState] | None = None
 
 
 class EvaluationOutput(BaseModel):
diff --git a/tests/strands_evals/evaluators/deterministic/test_environment_state.py b/tests/strands_evals/evaluators/deterministic/test_environment_state.py
new file mode 100644
index 00000000..867ceb82
--- /dev/null
+++ b/tests/strands_evals/evaluators/deterministic/test_environment_state.py
@@ -0,0 +1,177 @@
+import pytest
+
+from strands_evals.evaluators.deterministic.environment_state import StateEquals
+from strands_evals.types import EnvironmentState, EvaluationData
+
+
+class TestStateEquals:
+    def test_matches_expected_state_by_name(self):
+        evaluator = StateEquals(name="test_results")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+            expected_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = evaluator.evaluate(data)
+        assert len(results) == 1
+        assert results[0].score == 1.0
+        assert results[0].test_pass is True
+
+    def test_fails_when_state_differs(self):
+        evaluator = StateEquals(name="test_results")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 1})],
+            expected_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = evaluator.evaluate(data)
+        assert len(results) == 1
+        assert results[0].score == 0.0
+        assert results[0].test_pass is False
+
+    def test_fails_when_named_state_not_found(self):
+        evaluator = StateEquals(name="test_results")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="other_state", state="something")],
+            expected_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].score == 0.0
+        assert results[0].test_pass is False
+
+    def test_none_actual_environment_state(self):
+        evaluator = StateEquals(name="test_results")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=None,
+            expected_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].score == 0.0
+        assert results[0].test_pass is False
+
+    def test_explicit_value_takes_precedence(self):
+        evaluator = StateEquals(name="test_results", value={"exit_code": 0})
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+            expected_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 99})],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].test_pass is True
+
+    def test_explicit_value_without_expected_state(self):
+        evaluator = StateEquals(name="test_results", value=42)
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state=42)],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].test_pass is True
+
+    def test_non_string_state_values(self):
+        evaluator = StateEquals(name="count")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="count", state=42)],
+            expected_environment_state=[EnvironmentState(name="count", state=42)],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].test_pass is True
+
+    def test_list_state_values(self):
+        evaluator = StateEquals(name="files")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="files", state=["a.py", "b.py"])],
+            expected_environment_state=[EnvironmentState(name="files", state=["a.py", "b.py"])],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].test_pass is True
+
+    def test_multiple_states_finds_correct_one(self):
+        evaluator = StateEquals(name="db_state")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[
+                EnvironmentState(name="test_results", state={"exit_code": 0}),
+                EnvironmentState(name="db_state", state={"rows": 5}),
+            ],
+            expected_environment_state=[
+                EnvironmentState(name="db_state", state={"rows": 5}),
+            ],
+        )
+        results = evaluator.evaluate(data)
+        assert results[0].test_pass is True
+
+    def test_reason_on_match(self):
+        evaluator = StateEquals(name="test_results", value={"exit_code": 0})
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = evaluator.evaluate(data)
+        assert "matches" in results[0].reason
+
+    def test_reason_on_mismatch(self):
+        evaluator = StateEquals(name="test_results", value={"exit_code": 0})
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 1})],
+        )
+        results = evaluator.evaluate(data)
+        assert "does not match" in results[0].reason
+
+    def test_reason_on_state_not_found(self):
+        evaluator = StateEquals(name="missing")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="other", state="x")],
+            expected_environment_state=[EnvironmentState(name="missing", state="x")],
+        )
+        results = evaluator.evaluate(data)
+        assert "not found" in results[0].reason
+
+    @pytest.mark.asyncio
+    async def test_evaluate_async_delegates_to_evaluate(self):
+        evaluator = StateEquals(name="test_results", value={"exit_code": 0})
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        results = await evaluator.evaluate_async(data)
+        assert results[0].test_pass is True
+
+    def test_raises_when_no_expected_environment_state(self):
+        evaluator = StateEquals(name="test_results")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
+        )
+        with pytest.raises(ValueError, match="no expected value"):
+            evaluator.evaluate(data)
+
+    def test_raises_when_name_not_in_expected_environment_state(self):
+        evaluator = StateEquals(name="missing_state")
+        data = EvaluationData(
+            input="q",
+            actual_environment_state=[EnvironmentState(name="missing_state", state="x")],
+            expected_environment_state=[EnvironmentState(name="other_state", state="y")],
+        )
+        with pytest.raises(ValueError, match="not found in expected_environment_state"):
+            evaluator.evaluate(data)
+
+    def test_to_dict(self):
+        evaluator = StateEquals(name="test_results", value={"exit_code": 0})
+        d = evaluator.to_dict()
+        assert d["evaluator_type"] == "StateEquals"
+        assert d["name"] == "test_results"
+        assert d["value"] == {"exit_code": 0}
+
+    def test_to_dict_no_value(self):
+        evaluator = StateEquals(name="test_results")
+        d = evaluator.to_dict()
+        assert d["evaluator_type"] == "StateEquals"
+        assert d["name"] == "test_results"
+        assert "value" not in d
diff --git a/tests/strands_evals/test_experiment.py b/tests/strands_evals/test_experiment.py
index 3e1a0010..f1e52fc3 100644
--- a/tests/strands_evals/test_experiment.py
+++ b/tests/strands_evals/test_experiment.py
@@ -367,6 +367,7 @@ def test_experiment_to_dict_non_empty(mock_evaluator):
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -397,6 +398,7 @@ def test_experiment_to_dict_OutputEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -429,6 +431,7 @@ def test_experiment_to_dict_OutputEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -452,6 +455,7 @@ def test_experiment_to_dict_TrajectoryEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": ["step1", "step2"],
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -480,6 +484,7 @@ def test_experiment_to_dict_TrajectoryEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": ["step1", "step2"],
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
            }
         ],
@@ -511,6 +516,7 @@ def test_experiment_to_dict_InteractionsEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": interactions,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -542,6 +548,7 @@ def test_experiment_to_dict_InteractionsEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": interactions,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
@@ -572,6 +579,7 @@ def test_experiment_to_dict_case_dict():
                 "expected_output": {"field2": "world"},
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": {},
             }
         ],
@@ -598,6 +606,7 @@ def simple_echo(query):
                 "expected_output": None,
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "expected_environment_state": None,
                 "metadata": None,
             }
         ],
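
Illustrative usage (not part of the diff above): a minimal sketch of how a task function could surface environment state for the new StateEquals evaluator, using only the types introduced in this change; the task body, state name, and values are hypothetical.

from strands_evals.evaluators import StateEquals
from strands_evals.types import EnvironmentState, EvaluationData, TaskOutput


def my_task(case) -> TaskOutput:
    # Hypothetical task: run the agent, then snapshot whatever environment
    # state the evaluation cares about (here, a fake test-run result).
    return {
        "output": "done",
        "environment_state": [EnvironmentState(name="test_results", state={"exit_code": 0})],
    }


# The experiment runner copies the task output's "environment_state" into
# EvaluationData.actual_environment_state; constructed directly here for brevity.
data = EvaluationData(
    input="fix the failing test",
    actual_environment_state=[EnvironmentState(name="test_results", state={"exit_code": 0})],
)

# With an explicit `value`, no expected_environment_state is needed on the case.
result = StateEquals(name="test_results", value={"exit_code": 0}).evaluate(data)[0]
assert result.test_pass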