Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands_evals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import evaluators, extractors, generators, simulation, telemetry, types
from . import evaluators, extractors, generators, providers, simulation, telemetry, types
from .case import Case
from .experiment import Experiment
from .simulation import ActorSimulator, UserSimulator
Expand All @@ -9,6 +9,7 @@
"Case",
"evaluators",
"extractors",
"providers",
"types",
"generators",
"simulation",
Expand Down
19 changes: 19 additions & 0 deletions src/strands_evals/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .exceptions import (
ProviderError,
SessionNotFoundError,
TraceNotFoundError,
TraceProviderError,
)
from .trace_provider import (
SessionFilter,
TraceProvider,
)

__all__ = [
"ProviderError",
"SessionFilter",
"SessionNotFoundError",
"TraceNotFoundError",
"TraceProvider",
"TraceProviderError",
]
25 changes: 25 additions & 0 deletions src/strands_evals/providers/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Exceptions for trace providers."""


class TraceProviderError(Exception):
"""Base exception for trace provider errors."""

pass


class SessionNotFoundError(TraceProviderError):
"""No traces found for the given session ID."""

pass


class TraceNotFoundError(TraceProviderError):
"""Trace with the given ID not found."""

pass


class ProviderError(TraceProviderError):
"""Provider is unreachable or returned an error."""

pass
98 changes: 98 additions & 0 deletions src/strands_evals/providers/trace_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""TraceProvider interface for retrieving agent trace data from observability backends."""

from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

from ..types.evaluation import TaskOutput


@dataclass
class SessionFilter:
"""Filter criteria for discovering sessions.

Universal fields are defined here. Provider-specific parameters
go in `additional_fields`.
"""

start_time: datetime | None = None
end_time: datetime | None = None
limit: int | None = None
additional_fields: dict[str, Any] = field(default_factory=dict)


class TraceProvider(ABC):
"""Retrieves agent trace data from observability backends for evaluation.

Implementations handle authentication, pagination, and conversion from
provider-native formats to the types the evals system consumes.
"""

@abstractmethod
def get_evaluation_data(self, session_id: str) -> TaskOutput:
"""Retrieve all data needed to evaluate a session.

This is the primary access pattern — given a session ID, fetch all
traces, extract the agent output and trajectory, and return them
in a format ready for evaluation.

Args:
session_id: The session identifier (maps to Strands session_id)

Returns:
TaskOutput with 'output' (final agent response) and
'trajectory' (Session containing all traces/spans)

Raises:
SessionNotFoundError: If no traces found for session_id
ProviderError: If the provider is unreachable or returns an error
"""
...

def list_sessions(
self,
session_filter: SessionFilter | None = None,
) -> Iterator[str]:
"""Discover session IDs matching filter criteria.

Returns session IDs that can be fed to get_evaluation_data().
Not abstract — providers override to enable session discovery.

Args:
session_filter: Optional filter. If None, provider-specific defaults apply.

Yields:
Session ID strings

Raises:
NotImplementedError: If the provider does not support session discovery
ProviderError: If the provider is unreachable or returns an error
"""
raise NotImplementedError(
"This provider does not support session discovery. "
"Use get_evaluation_data() with a known session_id instead."
)

def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput:
"""Fetch a single trace and return its evaluation data.

Useful when someone has a trace_id (e.g., from a Langfuse link) but
not a session_id, or for single-shot agent runs without sessions.
Not abstract — providers override to enable trace-level retrieval.

Args:
trace_id: The unique trace identifier

Returns:
TaskOutput with 'output' and 'trajectory' for the single trace

Raises:
NotImplementedError: If the provider does not support trace-level retrieval
TraceNotFoundError: If trace doesn't exist
ProviderError: If the provider is unreachable or returns an error
"""
raise NotImplementedError(
"This provider does not support trace-level retrieval. Use get_evaluation_data() with a session_id instead."
)
6 changes: 3 additions & 3 deletions src/strands_evals/types/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class Interaction(TypedDict, total=False):
"""
Represents a single interaction in a multi-agent or multi-step system.
"""Represents a single interaction in a multi-agent or multi-step system.


Used to capture the communication flow and dependencies between different
components (agents, tools, or processing nodes) during task execution.
Expand Down Expand Up @@ -56,7 +56,7 @@ class TaskOutput(TypedDict, total=False):
"""

output: Any
trajectory: list[Any]
trajectory: Union[list[Any], Session, None]
interactions: list[Interaction]
input: Any

Expand Down
Empty file.
186 changes: 186 additions & 0 deletions tests/strands_evals/providers/test_trace_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""Tests for TraceProvider ABC, SessionFilter, and exception hierarchy."""

from collections.abc import Iterator
from datetime import datetime, timezone

import pytest

from strands_evals.providers.exceptions import (
ProviderError,
SessionNotFoundError,
TraceNotFoundError,
TraceProviderError,
)
from strands_evals.providers.trace_provider import (
SessionFilter,
TraceProvider,
)
from strands_evals.types.evaluation import TaskOutput
from strands_evals.types.trace import Session, Trace


class ConcreteProvider(TraceProvider):
"""Minimal concrete implementation for testing the ABC."""

def __init__(self, session: Session | None = None):
self._session = session

def get_evaluation_data(self, session_id: str) -> TaskOutput:
if self._session is None:
raise SessionNotFoundError(f"No session found: {session_id}")
return TaskOutput(
output="test response",
trajectory=self._session,
)


class FullProvider(TraceProvider):
"""Provider that overrides all optional methods."""

def __init__(self, sessions: dict[str, Session] | None = None, session_ids: list[str] | None = None):
self._sessions = sessions or {}
self._session_ids = session_ids or []

def get_evaluation_data(self, session_id: str) -> TaskOutput:
if session_id not in self._sessions:
raise SessionNotFoundError(f"No session found: {session_id}")
return TaskOutput(
output="test response",
trajectory=self._sessions[session_id],
)

def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]:
yield from self._session_ids

def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput:
for session in self._sessions.values():
for trace in session.traces:
if trace.trace_id == trace_id:
return TaskOutput(
output="test response",
trajectory=session,
)
raise TraceNotFoundError(f"No trace found: {trace_id}")


# --- SessionFilter tests ---


class TestSessionFilter:
def test_defaults(self):
f = SessionFilter()
assert f.start_time is None
assert f.end_time is None
assert f.limit is None
assert f.additional_fields == {}

def test_with_all_fields(self):
start = datetime(2025, 1, 1, tzinfo=timezone.utc)
end = datetime(2025, 1, 31, tzinfo=timezone.utc)
f = SessionFilter(
start_time=start,
end_time=end,
limit=50,
additional_fields={"environment": "production"},
)
assert f.start_time == start
assert f.end_time == end
assert f.limit == 50
assert f.additional_fields == {"environment": "production"}

def test_additional_fields_default_is_independent(self):
"""Each instance gets its own dict (no shared mutable default)."""
f1 = SessionFilter()
f2 = SessionFilter()
f1.additional_fields["key"] = "value"
assert "key" not in f2.additional_fields


# --- Exception hierarchy tests ---


class TestExceptionHierarchy:
def test_trace_provider_error_is_exception(self):
assert issubclass(TraceProviderError, Exception)

def test_session_not_found_is_trace_provider_error(self):
assert issubclass(SessionNotFoundError, TraceProviderError)

def test_trace_not_found_is_trace_provider_error(self):
assert issubclass(TraceNotFoundError, TraceProviderError)

def test_provider_error_is_trace_provider_error(self):
assert issubclass(ProviderError, TraceProviderError)

def test_exceptions_carry_message(self):
err = SessionNotFoundError("session-123 not found")
assert "session-123 not found" in str(err)

def test_catching_base_catches_all(self):
"""All provider exceptions can be caught with TraceProviderError."""
for exc_class in (SessionNotFoundError, TraceNotFoundError, ProviderError):
with pytest.raises(TraceProviderError):
raise exc_class("test")


# --- TraceProvider ABC tests ---


class TestTraceProviderABC:
def test_cannot_instantiate_without_get_evaluation_data(self):
with pytest.raises(TypeError):
TraceProvider() # type: ignore[abstract]

def test_concrete_provider_instantiates(self):
provider = ConcreteProvider()
assert isinstance(provider, TraceProvider)

def test_get_evaluation_data_returns_task_output(self):
session = Session(session_id="s1", traces=[])
provider = ConcreteProvider(session=session)
result = provider.get_evaluation_data("s1")
assert result["output"] == "test response"
assert result["trajectory"] == session

def test_get_evaluation_data_raises_session_not_found(self):
provider = ConcreteProvider(session=None)
with pytest.raises(SessionNotFoundError, match="No session found"):
provider.get_evaluation_data("missing")

def test_list_sessions_default_raises_not_implemented(self):
provider = ConcreteProvider()
with pytest.raises(NotImplementedError, match="does not support session discovery"):
list(provider.list_sessions())

def test_get_evaluation_data_by_trace_id_default_raises_not_implemented(self):
provider = ConcreteProvider()
with pytest.raises(NotImplementedError, match="does not support trace-level retrieval"):
provider.get_evaluation_data_by_trace_id("trace-123")


class TestFullProvider:
def test_list_sessions_yields_ids(self):
provider = FullProvider(session_ids=["s1", "s2", "s3"])
result = list(provider.list_sessions())
assert result == ["s1", "s2", "s3"]

def test_list_sessions_with_filter(self):
provider = FullProvider(session_ids=["s1"])
f = SessionFilter(limit=10)
result = list(provider.list_sessions(session_filter=f))
assert result == ["s1"]

def test_get_evaluation_data_by_trace_id(self):
session = Session(
session_id="s1",
traces=[Trace(trace_id="t1", session_id="s1", spans=[])],
)
provider = FullProvider(sessions={"s1": session})
result = provider.get_evaluation_data_by_trace_id("t1")
assert result["output"] == "test response"
assert result["trajectory"] == session

def test_get_evaluation_data_by_trace_id_not_found(self):
provider = FullProvider(sessions={})
with pytest.raises(TraceNotFoundError, match="No trace found"):
provider.get_evaluation_data_by_trace_id("missing")