3 changes: 2 additions & 1 deletion src/strands_evals/case.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
from typing_extensions import Any, Generic

from .types.evaluation import InputT, Interaction, OutputT
from .types.evaluation import EnvironmentState, InputT, Interaction, OutputT


class Case(BaseModel, Generic[InputT, OutputT]):
@@ -47,4 +47,5 @@ class Case(BaseModel, Generic[InputT, OutputT]):
expected_output: OutputT | None = None
expected_trajectory: list[Any] | None = None
expected_interactions: list[Interaction] | None = None
expected_environment_state: list[EnvironmentState] | None = None
metadata: dict[str, Any] | None = None
3 changes: 2 additions & 1 deletion src/strands_evals/evaluators/__init__.py
@@ -1,6 +1,6 @@
from .coherence_evaluator import CoherenceEvaluator
from .conciseness_evaluator import ConcisenessEvaluator
from .deterministic import Contains, Equals, StartsWith, ToolCalled
from .deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled
from .evaluator import Evaluator
from .faithfulness_evaluator import FaithfulnessEvaluator
from .goal_success_rate_evaluator import GoalSuccessRateEvaluator
@@ -30,5 +30,6 @@
"Contains",
"Equals",
"StartsWith",
"StateEquals",
"ToolCalled",
]
2 changes: 2 additions & 0 deletions src/strands_evals/evaluators/deterministic/__init__.py
@@ -1,9 +1,11 @@
from .environment_state import StateEquals
from .output import Contains, Equals, StartsWith
from .trajectory import ToolCalled

__all__ = [
"Contains",
"Equals",
"StartsWith",
"StateEquals",
"ToolCalled",
]
67 changes: 67 additions & 0 deletions src/strands_evals/evaluators/deterministic/environment_state.py
@@ -0,0 +1,67 @@
from typing_extensions import Any

from ...types.evaluation import EnvironmentState, EvaluationData, EvaluationOutput, InputT, OutputT
from ..evaluator import Evaluator


def _find_state_by_name(states: list[EnvironmentState], name: str) -> EnvironmentState | None:
"""Find an EnvironmentState by name in a list of states."""
for state in states:
if state.name == name:
return state
return None


class StateEquals(Evaluator[InputT, OutputT]):
"""Checks if a named environment state matches an expected value."""

def __init__(self, name: str, value: Any | None = None):
super().__init__()
self.name = name
self.value = value

def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
if not evaluation_case.actual_environment_state:
return [
EvaluationOutput(
score=0.0,
test_pass=False,
reason=f"state '{self.name}' not found: actual_environment_state is empty or None",
)
]

actual_state = _find_state_by_name(evaluation_case.actual_environment_state, self.name)
if actual_state is None:
return [
EvaluationOutput(
score=0.0,
test_pass=False,
reason=f"state '{self.name}' not found in actual_environment_state",
)
]

if self.value is not None:
expected = self.value
elif evaluation_case.expected_environment_state:
expected_state = _find_state_by_name(evaluation_case.expected_environment_state, self.name)
if expected_state is None:
raise ValueError(
f"state '{self.name}' not found in expected_environment_state and no explicit value provided"
)
expected = expected_state.state
else:
raise ValueError(
f"no expected value for state '{self.name}': provide value param or expected_environment_state"
)

match = actual_state.state == expected
return [
EvaluationOutput(
score=1.0 if match else 0.0,
test_pass=match,
reason=f"state '{self.name}' {'matches' if match else 'does not match'} expected value",
)
]

async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
return self.evaluate(evaluation_case)
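
A minimal usage sketch of the new `StateEquals` evaluator, assuming `EvaluationData` can be constructed with only the fields shown below (the model may require or use other fields in practice); the input string and state values are hypothetical:

```python
from strands_evals.evaluators import StateEquals
from strands_evals.types import EnvironmentState, EvaluationData

# State captured by the task function, and state declared as expected on the case.
actual = [EnvironmentState(name="test_results", state={"passed": 3, "failed": 0})]
expected = [EnvironmentState(name="test_results", state={"passed": 3, "failed": 0})]

data = EvaluationData(
    input="run the test suite",  # hypothetical input value
    actual_environment_state=actual,
    expected_environment_state=expected,
)

# Explicit expected value passed to the evaluator.
print(StateEquals(name="test_results", value={"passed": 3, "failed": 0}).evaluate(data)[0].test_pass)  # True

# No value given: the evaluator falls back to the case's expected_environment_state.
print(StateEquals(name="test_results").evaluate(data)[0].test_pass)  # True
```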
@@ -7,6 +7,7 @@ def compose_test_prompt(
include_inputs: bool,
uses_trajectory: bool = False,
trajectory_description: dict | None = None,
uses_environment_state: bool = False,
) -> str:
"""
Compose the prompt for a test case evaluation.
@@ -17,19 +17,21 @@
include_inputs: Whether to include the input in the prompt
uses_trajectory: Whether this is a trajectory-based evaluation
trajectory_description: A dictionary describing the type of trajectory expected for this evaluation.
uses_environment_state: Whether this is an environment-state-based evaluation

Returns:
str: The formatted evaluation prompt

Raises:
Exception: If actual_output is missing for non-trajectory evaluations
Exception: If actual_output is missing for output-only evaluations
Exception: If actual_trajectory is missing for trajectory evaluations
Exception: If actual_environment_state is missing for environment state evaluations
"""
evaluation_prompt = "Evaluate this singular test case. THE FINAL SCORE MUST BE A DECIMAL BETWEEN 0.0 AND 1.0 (NOT 0 to 10 OR 0 to 100). \n"
if include_inputs:
evaluation_prompt += f"<Input>{evaluation_case.input}</Input>\n"

if uses_trajectory: # trajectory evaluations don't require actual_output
if uses_trajectory or uses_environment_state: # these evaluations don't require actual_output
if evaluation_case.actual_output:
evaluation_prompt += f"<Output>{evaluation_case.actual_output}</Output>\n"
else:
@@ -53,6 +56,18 @@
if trajectory_description:
evaluation_prompt += f"<TrajectoryDescription>{trajectory_description}</TrajectoryDescription>\n"

if uses_environment_state:
if evaluation_case.actual_environment_state is None:
raise Exception("Please make sure the task function returns a dictionary with the key 'environment_state'.")
evaluation_prompt += (
f"<ActualEnvironmentState>{evaluation_case.actual_environment_state}</ActualEnvironmentState>\n"
)

if evaluation_case.expected_environment_state:
evaluation_prompt += (
f"<ExpectedEnvironmentState>{evaluation_case.expected_environment_state}</ExpectedEnvironmentState>\n"
)

evaluation_prompt += f"<Rubric>{rubric}</Rubric>"

return evaluation_prompt
7 changes: 6 additions & 1 deletion src/strands_evals/experiment.py
@@ -17,7 +17,7 @@
from typing_extensions import Any, Generic

from .case import Case
from .evaluators.deterministic import Contains, Equals, StartsWith, ToolCalled
from .evaluators.deterministic import Contains, Equals, StartsWith, StateEquals, ToolCalled
from .evaluators.evaluator import Evaluator
from .evaluators.interactions_evaluator import InteractionsEvaluator
from .evaluators.output_evaluator import OutputEvaluator
@@ -202,13 +202,15 @@ def _run_task(
expected_output=case.expected_output,
expected_trajectory=case.expected_trajectory,
expected_interactions=case.expected_interactions,
expected_environment_state=case.expected_environment_state,
metadata=case.metadata,
)
task_output = task(case)
if isinstance(task_output, dict): # could be evaluating the trajectory as well
evaluation_context.actual_output = task_output.get("output")
evaluation_context.actual_trajectory = task_output.get("trajectory")
evaluation_context.actual_interactions = task_output.get("interactions")
evaluation_context.actual_environment_state = task_output.get("environment_state")
new_input = task_output.get("input", None) # allows the user to update the input in the task function
if new_input is not None:
evaluation_context.input = new_input
@@ -238,6 +240,7 @@ async def _run_task_async(
expected_output=case.expected_output,
expected_trajectory=case.expected_trajectory,
expected_interactions=case.expected_interactions,
expected_environment_state=case.expected_environment_state,
metadata=case.metadata,
)

@@ -252,6 +255,7 @@
evaluation_context.actual_output = task_output.get("output")
evaluation_context.actual_trajectory = task_output.get("trajectory")
evaluation_context.actual_interactions = task_output.get("interactions")
evaluation_context.actual_environment_state = task_output.get("environment_state")
# allows the user to update the input in the task function
new_input = task_output.get("input", None)
if new_input is not None:
@@ -798,6 +802,7 @@ def from_dict(cls, data: dict, custom_evaluators: list[type[Evaluator]] | None =
"Equals": Equals,
"Contains": Contains,
"StartsWith": StartsWith,
"StateEquals": StateEquals,
"ToolCalled": ToolCalled,
}
all_evaluators: dict[str, type[Evaluator]] = {
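
For context, a hedged sketch of a task function that feeds this path: when the task returns a dict, `_run_task` (and `_run_task_async`) copies the `"environment_state"` entry onto the evaluation context as `actual_environment_state`. The function name, case handling, and result values here are placeholders:

```python
from strands_evals.case import Case
from strands_evals.types import EnvironmentState


def run_suite_task(case: Case) -> dict:
    # Run the system under test here; this stand-in fabricates a result.
    result = {"passed": 3, "failed": 0}
    return {
        "output": "3 tests passed, 0 failed",
        # Picked up via task_output.get("environment_state") and stored as
        # actual_environment_state on the evaluation context.
        "environment_state": [EnvironmentState(name="test_results", state=result)],
    }
```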
3 changes: 2 additions & 1 deletion src/strands_evals/types/__init__.py
@@ -1,7 +1,8 @@
from .evaluation import EvaluationData, EvaluationOutput, InputT, Interaction, OutputT, TaskOutput
from .evaluation import EnvironmentState, EvaluationData, EvaluationOutput, InputT, Interaction, OutputT, TaskOutput
from .simulation import ActorProfile, ActorResponse

__all__ = [
"EnvironmentState",
"Interaction",
"TaskOutput",
"EvaluationData",
15 changes: 15 additions & 0 deletions src/strands_evals/types/evaluation.py
@@ -33,6 +33,18 @@ class Interaction(TypedDict, total=False):
messages: list


class EnvironmentState(BaseModel):
"""A named piece of environment state captured after task execution.

Attributes:
name: Identifier for this state (e.g., "test_results", "file_system")
state: The captured state data
"""

name: str
state: Any


class TaskOutput(TypedDict, total=False):
"""
Structured output format for task functions that return complex results.
@@ -59,6 +71,7 @@ class TaskOutput(TypedDict, total=False):
trajectory: Union[list[Any], Session, None]
interactions: list[Interaction]
input: Any
environment_state: list[EnvironmentState]


class EvaluationData(BaseModel, Generic[InputT, OutputT]):
@@ -86,6 +99,8 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]):
metadata: dict[str, Any] | None = None
actual_interactions: list[Interaction] | None = None
expected_interactions: list[Interaction] | None = None
actual_environment_state: list[EnvironmentState] | None = None
expected_environment_state: list[EnvironmentState] | None = None


class EvaluationOutput(BaseModel):
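
A short sketch of the new `EnvironmentState` type alongside `TaskOutput` (a `total=False` TypedDict, so only the keys a task actually produces need to be set); the state name and contents are hypothetical:

```python
from strands_evals.types import EnvironmentState, TaskOutput

# A named snapshot of the environment after the task ran.
fs_state = EnvironmentState(name="file_system", state={"/tmp/report.txt": "done"})

task_output: TaskOutput = {
    "output": "report written",
    "environment_state": [fs_state],
}
```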