diff --git a/deepeval/integrations/pydantic_ai/__init__.py b/deepeval/integrations/pydantic_ai/__init__.py
index 0aabb7db7e..a3315e263d 100644
--- a/deepeval/integrations/pydantic_ai/__init__.py
+++ b/deepeval/integrations/pydantic_ai/__init__.py
@@ -1,3 +1,5 @@
+from .agent import DeepEvalPydanticAIAgent as Agent
 from .patcher import instrument as instrument_pydantic_ai
+from .otel import instrument_pydantic_ai as otel_instrument_pydantic_ai
 
-__all__ = ["instrument_pydantic_ai"]
+__all__ = ["instrument_pydantic_ai", "Agent", "otel_instrument_pydantic_ai"]
diff --git a/deepeval/integrations/pydantic_ai/agent.py b/deepeval/integrations/pydantic_ai/agent.py
new file mode 100644
index 0000000000..f53be68982
--- /dev/null
+++ b/deepeval/integrations/pydantic_ai/agent.py
@@ -0,0 +1,280 @@
+import inspect
+from typing import Optional, List, Generic
+from contextvars import ContextVar
+from contextlib import asynccontextmanager
+
+from deepeval.prompt import Prompt
+from deepeval.tracing.types import AgentSpan
+from deepeval.tracing.tracing import Observer
+from deepeval.metrics.base_metric import BaseMetric
+from deepeval.tracing.context import current_span_context
+from deepeval.integrations.pydantic_ai.utils import extract_tools_called
+
+try:
+    from pydantic_ai.agent import Agent
+    from pydantic_ai.tools import AgentDepsT
+    from pydantic_ai.output import OutputDataT
+    from deepeval.integrations.pydantic_ai.utils import create_patched_tool, update_trace_context, patch_llm_model
+    is_pydantic_ai_installed = True
+except ImportError:
+    is_pydantic_ai_installed = False
+
+def pydantic_ai_installed():
+    if not is_pydantic_ai_installed:
+        raise ImportError(
+            "Pydantic AI is not installed. Please install it with `pip install pydantic-ai`."
+        )
+
+_IS_RUN_SYNC = ContextVar("deepeval_is_run_sync", default=False)
+
+class DeepEvalPydanticAIAgent(Agent[AgentDepsT, OutputDataT], Generic[AgentDepsT, OutputDataT]):
+
+    trace_name: Optional[str] = None
+    trace_tags: Optional[List[str]] = None
+    trace_metadata: Optional[dict] = None
+    trace_thread_id: Optional[str] = None
+    trace_user_id: Optional[str] = None
+    trace_metric_collection: Optional[str] = None
+    trace_metrics: Optional[List[BaseMetric]] = None
+
+    llm_prompt: Optional[Prompt] = None
+    llm_metrics: Optional[List[BaseMetric]] = None
+    llm_metric_collection: Optional[str] = None
+
+    agent_metrics: Optional[List[BaseMetric]] = None
+    agent_metric_collection: Optional[str] = None
+
+    def __init__(
+        self,
+        *args,
+        trace_name: Optional[str] = None,
+        trace_tags: Optional[List[str]] = None,
+        trace_metadata: Optional[dict] = None,
+        trace_thread_id: Optional[str] = None,
+        trace_user_id: Optional[str] = None,
+        trace_metric_collection: Optional[str] = None,
+        trace_metrics: Optional[List[BaseMetric]] = None,
+        llm_metric_collection: Optional[str] = None,
+        llm_metrics: Optional[List[BaseMetric]] = None,
+        llm_prompt: Optional[Prompt] = None,
+        agent_metric_collection: Optional[str] = None,
+        agent_metrics: Optional[List[BaseMetric]] = None,
+        **kwargs
+    ):
+        pydantic_ai_installed()
+
+        self.trace_name = trace_name
+        self.trace_tags = trace_tags
+        self.trace_metadata = trace_metadata
+        self.trace_thread_id = trace_thread_id
+        self.trace_user_id = trace_user_id
+        self.trace_metric_collection = trace_metric_collection
+        self.trace_metrics = trace_metrics
+
+        self.llm_metric_collection = llm_metric_collection
+        self.llm_metrics = llm_metrics
+        self.llm_prompt = llm_prompt
+
+        self.agent_metric_collection = agent_metric_collection
+        self.agent_metrics = agent_metrics
+
+        super().__init__(*args, **kwargs)
+
+        patch_llm_model(self._model, llm_metric_collection, llm_metrics, llm_prompt)  # TODO: Add dual patch guards
+
+
+    async def run(
+        self,
+        *args,
+
+        name: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        user_id: Optional[str] = None,
+        metadata: Optional[dict] = None,
+        thread_id: Optional[str] = None,
+        metrics: Optional[List[BaseMetric]] = None,
+        metric_collection: Optional[str] = None,
+
+        **kwargs
+    ):
+        sig = inspect.signature(super().run)
+        bound = sig.bind_partial(*args, **kwargs)
+        bound.apply_defaults()
+        input = bound.arguments.get("user_prompt", None)
+
+        agent_name = self.name if self.name is not None else "Agent"
+
+        with Observer(
+            span_type="agent" if not _IS_RUN_SYNC.get() else "custom",
+            func_name=agent_name if not _IS_RUN_SYNC.get() else "run",
+            function_kwargs={"input": input},
+            metrics=self.agent_metrics if not _IS_RUN_SYNC.get() else None,
+            metric_collection=self.agent_metric_collection if not _IS_RUN_SYNC.get() else None,
+        ) as observer:
+            result = await super().run(*args, **kwargs)
+            observer.result = result.output
+            update_trace_context(
+
+                trace_name=name if name is not None else self.trace_name,
+                trace_tags=tags if tags is not None else self.trace_tags,
+                trace_metadata=metadata if metadata is not None else self.trace_metadata,
+                trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
+                trace_user_id=user_id if user_id is not None else self.trace_user_id,
+                trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
+                trace_metrics=metrics if metrics is not None else self.trace_metrics,
+
+                trace_input=input,
+                trace_output=result.output,
+            )
+
+            agent_span: AgentSpan = current_span_context.get()
+            try:
+                agent_span.tools_called = extract_tools_called(result)
+            except Exception:
+                pass
+            # TODO: available tools
+            # TODO: agent handoffs
+
+        return result
+
+    def run_sync(
+        self,
+        *args,
+
+        name: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[dict] = None,
+        thread_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        metric_collection: Optional[str] = None,
+        metrics: Optional[List[BaseMetric]] = None,
+
+        **kwargs
+    ):
+        sig = inspect.signature(super().run_sync)
+        bound = sig.bind_partial(*args, **kwargs)
+        bound.apply_defaults()
+        input = bound.arguments.get("user_prompt", None)
+
+        token = _IS_RUN_SYNC.set(True)
+
+        agent_name = self.name if self.name is not None else "Agent"
+
+        with Observer(
+            span_type="agent",
+            func_name=agent_name,
+            function_kwargs={"input": input},
+            metrics=self.agent_metrics,
+            metric_collection=self.agent_metric_collection,
+        ) as observer:
+            try:
+                result = super().run_sync(*args, **kwargs)
+            finally:
+                _IS_RUN_SYNC.reset(token)
+
+            observer.result = result.output
+            update_trace_context(
+
+                trace_name=name if name is not None else self.trace_name,
+                trace_tags=tags if tags is not None else self.trace_tags,
+                trace_metadata=metadata if metadata is not None else self.trace_metadata,
+                trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
+                trace_user_id=user_id if user_id is not None else self.trace_user_id,
+                trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
+                trace_metrics=metrics if metrics is not None else self.trace_metrics,

+                trace_input=input,
+                trace_output=result.output,
+            )
+
+            agent_span: AgentSpan = current_span_context.get()
+            try:
+                agent_span.tools_called = extract_tools_called(result)
+            except Exception:
+                pass
+
+            # TODO: available tools
+            # TODO: agent handoffs
+
+        return result
+
+    @asynccontextmanager
+    async def run_stream(
+        self,
+        *args,
+
+        name: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[dict] = None,
+        thread_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        metric_collection: Optional[str] = None,
+        metrics: Optional[List[BaseMetric]] = None,
+
+        **kwargs
+    ):
+        sig = inspect.signature(super().run_stream)
+        super_params = sig.parameters
+        super_kwargs = {k: v for k, v in kwargs.items() if k in super_params}
+        bound = sig.bind_partial(*args, **super_kwargs)
+        bound.apply_defaults()
+        input = bound.arguments.get("user_prompt", None)
+
+        agent_name = self.name if self.name is not None else "Agent"
+
+        with Observer(
+            span_type="agent",
+            func_name=agent_name,
+            function_kwargs={"input": input},
+            metrics=self.agent_metrics,
+            metric_collection=self.agent_metric_collection,
+        ) as observer:
+            final_result = None
+            async with super().run_stream(*args, **super_kwargs) as result:
+                try:
+                    yield result
+                finally:
+                    try:
+                        final_result = await result.get_output()
+                        observer.result = final_result
+                    except Exception:
+                        pass
+
+            update_trace_context(
+
+                trace_name=name if name is not None else self.trace_name,
+                trace_tags=tags if tags is not None else self.trace_tags,
+                trace_metadata=metadata if metadata is not None else self.trace_metadata,
+                trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
+                trace_user_id=user_id if user_id is not None else self.trace_user_id,
+                trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
+                trace_metrics=metrics if metrics is not None else self.trace_metrics,
+
+                trace_input=input,
+                trace_output=final_result,
+            )
+            agent_span: AgentSpan = current_span_context.get()
+            try:
+                if final_result is not None:
+                    agent_span.tools_called = extract_tools_called(final_result)
+            except Exception:
+                pass
+
+    def tool(
+        self,
+        *args,
+        metrics: Optional[List[BaseMetric]] = None,
+        metric_collection: Optional[str] = None,
+        **kwargs
+    ):
+        # Direct decoration: @agent.tool
+        if args and callable(args[0]):
+            patched_func = create_patched_tool(args[0], metrics, metric_collection)
+            new_args = (patched_func,) + args[1:]
+            return super(DeepEvalPydanticAIAgent, self).tool(*new_args, **kwargs)
+        # Decoration with args: @agent.tool(...)
+ super_tool = super(DeepEvalPydanticAIAgent, self).tool + def decorator(func): + patched_func = create_patched_tool(func, metrics, metric_collection) + return super_tool(*args, **kwargs)(patched_func) + return decorator \ No newline at end of file diff --git a/deepeval/integrations/pydantic_ai/otel.py b/deepeval/integrations/pydantic_ai/otel.py index 706c857f4b..143ea31caa 100644 --- a/deepeval/integrations/pydantic_ai/otel.py +++ b/deepeval/integrations/pydantic_ai/otel.py @@ -26,7 +26,6 @@ def is_opentelemetry_available(): OTLP_ENDPOINT = "https://otel.confident-ai.com/v1/traces" - def instrument_pydantic_ai(api_key: Optional[str] = None): with capture_tracing_integration("pydantic_ai"): is_opentelemetry_available() diff --git a/deepeval/integrations/pydantic_ai/patcher.py b/deepeval/integrations/pydantic_ai/patcher.py index f7ad88dcff..ad326ce152 100644 --- a/deepeval/integrations/pydantic_ai/patcher.py +++ b/deepeval/integrations/pydantic_ai/patcher.py @@ -1,411 +1,484 @@ -import functools -import deepeval -from deepeval.tracing.types import LlmOutput, LlmToolCall -from pydantic_ai.agent import AgentRunResult -from deepeval.tracing.context import current_trace_context -from deepeval.tracing.types import AgentSpan, LlmSpan -from deepeval.tracing.tracing import Observer -from typing import List, Callable, Optional, Any -from deepeval.test_case.llm_test_case import ToolCall -from deepeval.metrics.base_metric import BaseMetric -from deepeval.confident.api import get_confident_api_key -from deepeval.integrations.pydantic_ai.otel import instrument_pydantic_ai -from deepeval.telemetry import capture_tracing_integration -from deepeval.prompt import Prompt -import inspect -from contextvars import ContextVar - -try: - from pydantic_ai.agent import Agent - from pydantic_ai.models import Model - from pydantic_ai.messages import ( - ModelResponse, - ModelRequest, - ModelResponsePart, - TextPart, - ToolCallPart, - SystemPromptPart, - ToolReturnPart, - UserPromptPart, - ) - from pydantic_ai._run_context import RunContext - from deepeval.integrations.pydantic_ai.utils import ( - extract_tools_called_from_llm_response, - extract_tools_called, - sanitize_run_context, - ) - - pydantic_ai_installed = True -except: - pydantic_ai_installed = True - -_IN_RUN_SYNC = ContextVar("deepeval_in_run_sync", default=False) -_INSTRUMENTED = False - +# import inspect +# import functools +# import warnings +# from typing import List, Callable, Optional, Any +# from deepeval.tracing.types import LlmOutput, LlmToolCall +# from pydantic_ai.agent import AgentRunResult +# from deepeval.tracing.context import current_trace_context +# from deepeval.tracing.types import AgentSpan, LlmSpan +# from deepeval.tracing.tracing import Observer +# from deepeval.test_case.llm_test_case import ToolCall +# from deepeval.metrics.base_metric import BaseMetric +# from deepeval.confident.api import get_confident_api_key +# from deepeval.integrations.pydantic_ai.otel import instrument_pydantic_ai +# from deepeval.telemetry import capture_tracing_integration +# from deepeval.prompt import Prompt +# import deepeval +# # from contextvars import ContextVar + +# try: +# from pydantic_ai.agent import Agent +# from pydantic_ai.models import Model +# from pydantic_ai.messages import ( +# ModelResponse, +# ModelRequest, +# ModelResponsePart, +# TextPart, +# ToolCallPart, +# SystemPromptPart, +# ToolReturnPart, +# UserPromptPart, +# ) +# from pydantic_ai._run_context import RunContext +# from deepeval.integrations.pydantic_ai.utils import ( +# 
extract_tools_called_from_llm_response, +# extract_tools_called, +# sanitize_run_context, +# ) + +# pydantic_ai_installed = True +# except: +# pydantic_ai_installed = True + +# # _IN_RUN_SYNC = ContextVar("deepeval_in_run_sync", default=False) +# # _INSTRUMENTED = False + + +import warnings +from typing import Optional def instrument(otel: Optional[bool] = False, api_key: Optional[str] = None): - global _INSTRUMENTED - if api_key: - deepeval.login(api_key) - - api_key = get_confident_api_key() - - if not api_key: - raise ValueError("No api key provided.") - - if otel: - instrument_pydantic_ai(api_key) - else: - with capture_tracing_integration("pydantic_ai"): - if _INSTRUMENTED: - return - _patch_agent_init() - _patch_agent_tool_decorator() - _INSTRUMENTED = True - - -################### Init Patches ################### - - -def _patch_agent_init(): - original_init = Agent.__init__ - - @functools.wraps(original_init) - def wrapper( - *args, - llm_metric_collection: Optional[str] = None, - llm_metrics: Optional[List[BaseMetric]] = None, - llm_prompt: Optional[Prompt] = None, - agent_metric_collection: Optional[str] = None, - agent_metrics: Optional[List[BaseMetric]] = None, - **kwargs - ): - result = original_init(*args, **kwargs) - _patch_llm_model( - args[0]._model, llm_metric_collection, llm_metrics, llm_prompt - ) # runtime patch of the model - _patch_agent_run(args[0], agent_metric_collection, agent_metrics) - _patch_agent_run_sync(args[0], agent_metric_collection, agent_metrics) - return result - - Agent.__init__ = wrapper - - -def _patch_agent_tool_decorator(): - original_tool = Agent.tool - - @functools.wraps(original_tool) - def wrapper( - *args, - metrics: Optional[List[BaseMetric]] = None, - metric_collection: Optional[str] = None, - **kwargs - ): - # Case 1: Direct decoration - @agent.tool - if args and callable(args[0]): - patched_func = _create_patched_tool( - args[0], metrics, metric_collection - ) - new_args = (patched_func,) + args[1:] - return original_tool(*new_args, **kwargs) - - # Case 2: Decoration with arguments - @agent.tool(metrics=..., metric_collection=...) 
- else: - # Return a decorator function that will receive the actual function - def decorator(func): - patched_func = _create_patched_tool( - func, metrics, metric_collection - ) - return original_tool(*args, **kwargs)(patched_func) - - return decorator - - Agent.tool = wrapper - - -################### Runtime Patches ################### - - -def _patch_agent_run_sync( - agent: Agent, - agent_metric_collection: Optional[str] = None, - agent_metrics: Optional[List[BaseMetric]] = None, -): - original_run_sync = agent.run_sync - - @functools.wraps(original_run_sync) - def wrapper( - *args, - metric_collection: Optional[str] = None, - metrics: Optional[List[BaseMetric]] = None, - name: Optional[str] = None, - tags: Optional[List[str]] = None, - metadata: Optional[dict] = None, - thread_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ): - - sig = inspect.signature(original_run_sync) - bound = sig.bind_partial(*args, **kwargs) - bound.apply_defaults() - input = bound.arguments.get("user_prompt", None) - - with Observer( - span_type="agent", - func_name="Agent", - function_kwargs={"input": input}, - metrics=agent_metrics, - metric_collection=agent_metric_collection, - ) as observer: - - token = _IN_RUN_SYNC.set(True) - try: - result = original_run_sync(*args, **kwargs) - finally: - _IN_RUN_SYNC.reset(token) - - observer.update_span_properties = ( - lambda agent_span: set_agent_span_attributes(agent_span, result) - ) - observer.result = result.output - - _update_trace_context( - trace_name=name, - trace_tags=tags, - trace_metadata=metadata, - trace_thread_id=thread_id, - trace_user_id=user_id, - trace_metric_collection=metric_collection, - trace_metrics=metrics, - trace_input=input, - trace_output=result.output, - ) - - return result - - agent.run_sync = wrapper - - -def _patch_agent_run( - agent: Agent, - agent_metric_collection: Optional[str] = None, - agent_metrics: Optional[List[BaseMetric]] = None, -): - original_run = agent.run - - @functools.wraps(original_run) - async def wrapper( - *args, - metric_collection: Optional[str] = None, - metrics: Optional[List[BaseMetric]] = None, - name: Optional[str] = None, - tags: Optional[List[str]] = None, - metadata: Optional[dict] = None, - thread_id: Optional[str] = None, - user_id: Optional[str] = None, - **kwargs - ): - sig = inspect.signature(original_run) - bound = sig.bind_partial(*args, **kwargs) - bound.apply_defaults() - input = bound.arguments.get("user_prompt", None) - - in_sync = _IN_RUN_SYNC.get() - with Observer( - span_type="agent" if not in_sync else "custom", - func_name="Agent" if not in_sync else "run", - function_kwargs={"input": input}, - metrics=agent_metrics if not in_sync else None, - metric_collection=agent_metric_collection if not in_sync else None, - ) as observer: - result = await original_run(*args, **kwargs) - observer.update_span_properties = ( - lambda agent_span: set_agent_span_attributes(agent_span, result) - ) - observer.result = result.output - - _update_trace_context( - trace_name=name, - trace_tags=tags, - trace_metadata=metadata, - trace_thread_id=thread_id, - trace_user_id=user_id, - trace_metric_collection=metric_collection, - trace_metrics=metrics, - trace_input=input, - trace_output=result.output, - ) - - return result - - agent.run = wrapper - - -def _patch_llm_model( - model: Model, - llm_metric_collection: Optional[str] = None, - llm_metrics: Optional[List[BaseMetric]] = None, - llm_prompt: Optional[Prompt] = None, -): - original_func = model.request - sig = 
inspect.signature(original_func) - - try: - model_name = model.model_name - except Exception: - model_name = "unknown" - - @functools.wraps(original_func) - async def wrapper(*args, **kwargs): - bound = sig.bind_partial(*args, **kwargs) - bound.apply_defaults() - request = bound.arguments.get("messages", []) - - with Observer( - span_type="llm", - func_name="LLM", - observe_kwargs={"model": model_name}, - metrics=llm_metrics, - metric_collection=llm_metric_collection, - ) as observer: - result = await original_func(*args, **kwargs) - observer.update_span_properties = ( - lambda llm_span: set_llm_span_attributes( - llm_span, request, result, llm_prompt - ) - ) - observer.result = result - return result - - model.request = wrapper - - -################### Helper Functions ################### - - -def _create_patched_tool( - func: Callable, - metrics: Optional[List[BaseMetric]] = None, - metric_collection: Optional[str] = None, -): - import asyncio - - original_func = func - - is_async = asyncio.iscoroutinefunction(original_func) - - if is_async: - - @functools.wraps(original_func) - async def async_wrapper(*args, **kwargs): - sanitized_args = sanitize_run_context(args) - sanitized_kwargs = sanitize_run_context(kwargs) - with Observer( - span_type="tool", - func_name=original_func.__name__, - metrics=metrics, - metric_collection=metric_collection, - function_kwargs={"args": sanitized_args, **sanitized_kwargs}, - ) as observer: - result = await original_func(*args, **kwargs) - observer.result = result - - return result - - return async_wrapper - else: - - @functools.wraps(original_func) - def sync_wrapper(*args, **kwargs): - sanitized_args = sanitize_run_context(args) - sanitized_kwargs = sanitize_run_context(kwargs) - with Observer( - span_type="tool", - func_name=original_func.__name__, - metrics=metrics, - metric_collection=metric_collection, - function_kwargs={"args": sanitized_args, **sanitized_kwargs}, - ) as observer: - result = original_func(*args, **kwargs) - observer.result = result - - return result - - return sync_wrapper - - -def _update_trace_context( - trace_name: Optional[str] = None, - trace_tags: Optional[List[str]] = None, - trace_metadata: Optional[dict] = None, - trace_thread_id: Optional[str] = None, - trace_user_id: Optional[str] = None, - trace_metric_collection: Optional[str] = None, - trace_metrics: Optional[List[BaseMetric]] = None, - trace_input: Optional[Any] = None, - trace_output: Optional[Any] = None, -): - - current_trace = current_trace_context.get() - current_trace.name = trace_name - current_trace.tags = trace_tags - current_trace.metadata = trace_metadata - current_trace.thread_id = trace_thread_id - current_trace.user_id = trace_user_id - current_trace.metric_collection = trace_metric_collection - current_trace.metrics = trace_metrics - current_trace.input = trace_input - current_trace.output = trace_output - - -def set_llm_span_attributes( - llm_span: LlmSpan, - requests: List[ModelRequest], - result: ModelResponse, - llm_prompt: Optional[Prompt] = None, -): - llm_span.prompt = llm_prompt - - input = [] - for request in requests: - for part in request.parts: - if isinstance(part, SystemPromptPart): - input.append({"role": "System", "content": part.content}) - elif isinstance(part, UserPromptPart): - input.append({"role": "User", "content": part.content}) - elif isinstance(part, ToolCallPart): - input.append( - { - "role": "Tool Call", - "name": part.tool_name, - "content": part.args_as_json_str(), - } - ) - elif isinstance(part, ToolReturnPart): - 
input.append(
-                    {
-                        "role": "Tool Return",
-                        "name": part.tool_name,
-                        "content": part.model_response_str(),
-                    }
-                )
-    llm_span.input = input
-
-    content = ""
-    tool_calls = []
-    for part in result.parts:
-        if isinstance(part, TextPart):
-            content += part.content + "\n"
-        elif isinstance(part, ToolCallPart):
-            tool_calls.append(
-                LlmToolCall(name=part.tool_name, args=part.args_as_dict())
-            )
-    llm_span.output = LlmOutput(
-        role="Assistant", content=content, tool_calls=tool_calls
+    """
+    DEPRECATED: This function is deprecated and will be removed in a future version.
+    Please use deepeval.integrations.pydantic_ai.Agent instead.
+    """
+    warnings.warn(
+        "The 'instrument_pydantic_ai()' function is deprecated and will be removed in a future version. "
+        "Please use deepeval.integrations.pydantic_ai.Agent instead. Refer to the documentation [link]",  # TODO: add the link
+        UserWarning,
+        stacklevel=2
     )
-    llm_span.tools_called = extract_tools_called_from_llm_response(result.parts)
-
-
-def set_agent_span_attributes(agent_span: AgentSpan, result: AgentRunResult):
-    agent_span.tools_called = extract_tools_called(result)
+
+    # Don't execute the original functionality
+    return
+
+    # Original code below (commented out to prevent execution)
+    # global _INSTRUMENTED
+    # if api_key:
+    #     deepeval.login(api_key)
+    #
+    # api_key = get_confident_api_key()
+    #
+    # if not api_key:
+    #     raise ValueError("No api key provided.")
+    #
+    # if otel:
+    #     instrument_pydantic_ai(api_key)
+    # else:
+    #     with capture_tracing_integration("pydantic_ai"):
+    #         if _INSTRUMENTED:
+    #             return
+    #         _patch_agent_init()
+    #         _patch_agent_tool_decorator()
+    #         _INSTRUMENTED = True
+
+
+# ################### Init Patches ###################
+
+
+# # def _patch_agent_init():
+# #     original_init = Agent.__init__
+
+# #     @functools.wraps(original_init)
+# #     def wrapper(
+# #         *args,
+# #         llm_metric_collection: Optional[str] = None,
+# #         llm_metrics: Optional[List[BaseMetric]] = None,
+# #         llm_prompt: Optional[Prompt] = None,
+# #         agent_metric_collection: Optional[str] = None,
+# #         agent_metrics: Optional[List[BaseMetric]] = None,
+# #         name: Optional[str] = None,
+# #         tags: Optional[List[str]] = None,
+# #         metadata: Optional[dict] = None,
+# #         thread_id: Optional[str] = None,
+# #         user_id: Optional[str] = None,
+# #         metric_collection: Optional[str] = None,
+# #         metrics: Optional[List[BaseMetric]] = None,
+# #         **kwargs
+# #     ):
+# #         result = original_init(*args, **kwargs)
+# #         _patch_llm_model(args[0]._model, llm_metric_collection, llm_metrics, llm_prompt) # runtime patch of the model
+# #         _patch_agent_run(
+# #             agent=args[0],
+# #             agent_metric_collection=agent_metric_collection,
+# #             agent_metrics=agent_metrics,
+# #             init_trace_name=name,
+# #             init_trace_tags=tags,
+# #             init_trace_metadata=metadata,
+# #             init_trace_thread_id=thread_id,
+# #             init_trace_user_id=user_id,
+# #             init_trace_metric_collection=metric_collection,
+# #             init_trace_metrics=metrics,
+# #         )
+# #         _patch_agent_run_sync(
+# #             agent=args[0],
+# #             agent_metric_collection=agent_metric_collection,
+# #             agent_metrics=agent_metrics,
+# #             init_trace_name=name,
+# #             init_trace_tags=tags,
+# #             init_trace_metadata=metadata,
+# #             init_trace_thread_id=thread_id,
+# #             init_trace_user_id=user_id,
+# #             init_trace_metric_collection=metric_collection,
+# #             init_trace_metrics=metrics,
+# #         )
+# #         return result

+# #     Agent.__init__ = wrapper


+# # def _patch_agent_tool_decorator():
+# #     original_tool = Agent.tool

+# #     @functools.wraps(original_tool)
+# #     def wrapper(
+# # *args, +# # metrics: Optional[List[BaseMetric]] = None, +# # metric_collection: Optional[str] = None, +# # **kwargs +# # ): +# # # Case 1: Direct decoration - @agent.tool +# # if args and callable(args[0]): +# # patched_func = _create_patched_tool( +# # args[0], metrics, metric_collection +# # ) +# # new_args = (patched_func,) + args[1:] +# # return original_tool(*new_args, **kwargs) + +# # # Case 2: Decoration with arguments - @agent.tool(metrics=..., metric_collection=...) +# # else: +# # # Return a decorator function that will receive the actual function +# # def decorator(func): +# # patched_func = _create_patched_tool( +# # func, metrics, metric_collection +# # ) +# # return original_tool(*args, **kwargs)(patched_func) + +# # return decorator + +# # Agent.tool = wrapper + + +# ################### Runtime Patches ################### + + +# # def _patch_agent_run_sync( +# # agent: Agent, +# # agent_metric_collection: Optional[str] = None, +# # agent_metrics: Optional[List[BaseMetric]] = None, +# # init_trace_name: Optional[str] = None, +# # init_trace_tags: Optional[List[str]] = None, +# # init_trace_metadata: Optional[dict] = None, +# # init_trace_thread_id: Optional[str] = None, +# # init_trace_user_id: Optional[str] = None, +# # init_trace_metric_collection: Optional[str] = None, +# # init_trace_metrics: Optional[List[BaseMetric]] = None, +# # ): +# # original_run_sync = agent.run_sync + +# # @functools.wraps(original_run_sync) +# # def wrapper( +# # *args, +# # metric_collection: Optional[str] = None, +# # metrics: Optional[List[BaseMetric]] = None, +# # name: Optional[str] = None, +# # tags: Optional[List[str]] = None, +# # metadata: Optional[dict] = None, +# # thread_id: Optional[str] = None, +# # user_id: Optional[str] = None, +# # **kwargs +# # ): + +# # sig = inspect.signature(original_run_sync) +# # bound = sig.bind_partial(*args, **kwargs) +# # bound.apply_defaults() +# # input = bound.arguments.get("user_prompt", None) + +# # with Observer( +# # span_type="agent", +# # func_name="Agent", +# # function_kwargs={"input": input}, +# # metrics=agent_metrics, +# # metric_collection=agent_metric_collection, +# # ) as observer: + +# # token = _IN_RUN_SYNC.set(True) +# # try: +# # result = original_run_sync(*args, **kwargs) +# # finally: +# # _IN_RUN_SYNC.reset(token) + +# # observer.update_span_properties = ( +# # lambda agent_span: set_agent_span_attributes(agent_span, result) +# # ) +# # observer.result = result.output + +# # _update_trace_context( +# # trace_name=init_trace_name if init_trace_name else name, +# # trace_tags=init_trace_tags if init_trace_tags else tags, +# # trace_metadata=init_trace_metadata if init_trace_metadata else metadata, +# # trace_thread_id=init_trace_thread_id if init_trace_thread_id else thread_id, +# # trace_user_id=init_trace_user_id if init_trace_user_id else user_id, +# # trace_metric_collection=init_trace_metric_collection if init_trace_metric_collection else metric_collection, +# # trace_metrics=init_trace_metrics if init_trace_metrics else metrics, +# # trace_input=input, +# # trace_output=result.output, +# # ) + +# # return result + +# # agent.run_sync = wrapper + + +# # def _patch_agent_run( +# # agent: Agent, +# # agent_metric_collection: Optional[str] = None, +# # agent_metrics: Optional[List[BaseMetric]] = None, +# # init_trace_name: Optional[str] = None, +# # init_trace_tags: Optional[List[str]] = None, +# # init_trace_metadata: Optional[dict] = None, +# # init_trace_thread_id: Optional[str] = None, +# # init_trace_user_id: 
Optional[str] = None, +# # init_trace_metric_collection: Optional[str] = None, +# # init_trace_metrics: Optional[List[BaseMetric]] = None, +# # ): +# # original_run = agent.run + +# # @functools.wraps(original_run) +# # async def wrapper( +# # *args, +# # metric_collection: Optional[str] = None, +# # metrics: Optional[List[BaseMetric]] = None, +# # name: Optional[str] = None, +# # tags: Optional[List[str]] = None, +# # metadata: Optional[dict] = None, +# # thread_id: Optional[str] = None, +# # user_id: Optional[str] = None, +# # **kwargs +# # ): +# # sig = inspect.signature(original_run) +# # bound = sig.bind_partial(*args, **kwargs) +# # bound.apply_defaults() +# # input = bound.arguments.get("user_prompt", None) + +# # in_sync = _IN_RUN_SYNC.get() +# # with Observer( +# # span_type="agent" if not in_sync else "custom", +# # func_name="Agent" if not in_sync else "run", +# # function_kwargs={"input": input}, +# # metrics=agent_metrics if not in_sync else None, +# # metric_collection=agent_metric_collection if not in_sync else None, +# # ) as observer: +# # print(args) +# # print(kwargs) +# # result = await original_run(*args, **kwargs) +# # observer.update_span_properties = ( +# # lambda agent_span: set_agent_span_attributes(agent_span, result) +# # ) +# # observer.result = result.output + +# # _update_trace_context( +# # trace_name=init_trace_name if init_trace_name else name, +# # trace_tags=init_trace_tags if init_trace_tags else tags, +# # trace_metadata=init_trace_metadata if init_trace_metadata else metadata, +# # trace_thread_id=init_trace_thread_id if init_trace_thread_id else thread_id, +# # trace_user_id=init_trace_user_id if init_trace_user_id else user_id, +# # trace_metric_collection=init_trace_metric_collection if init_trace_metric_collection else metric_collection, +# # trace_metrics=init_trace_metrics if init_trace_metrics else metrics, +# # trace_input=input, +# # trace_output=result.output, +# # ) + +# # return result + +# # agent.run = wrapper + + +# def patch_llm_model( +# model: Model, +# llm_metric_collection: Optional[str] = None, +# llm_metrics: Optional[List[BaseMetric]] = None, +# llm_prompt: Optional[Prompt] = None, +# ): +# original_func = model.request +# sig = inspect.signature(original_func) + +# try: +# model_name = model.model_name +# except Exception: +# model_name = "unknown" + +# @functools.wraps(original_func) +# async def wrapper(*args, **kwargs): +# bound = sig.bind_partial(*args, **kwargs) +# bound.apply_defaults() +# request = bound.arguments.get("messages", []) + +# with Observer( +# span_type="llm", +# func_name="LLM", +# observe_kwargs={"model": model_name}, +# metrics=llm_metrics, +# metric_collection=llm_metric_collection, +# ) as observer: +# result = await original_func(*args, **kwargs) +# observer.update_span_properties = ( +# lambda llm_span: set_llm_span_attributes( +# llm_span, request, result, llm_prompt +# ) +# ) +# observer.result = result +# return result + +# model.request = wrapper + + +# ################### Helper Functions ################### + + +# def create_patched_tool( +# func: Callable, +# metrics: Optional[List[BaseMetric]] = None, +# metric_collection: Optional[str] = None, +# ): +# import asyncio + +# original_func = func + +# is_async = asyncio.iscoroutinefunction(original_func) + +# if is_async: + +# @functools.wraps(original_func) +# async def async_wrapper(*args, **kwargs): +# sanitized_args = sanitize_run_context(args) +# sanitized_kwargs = sanitize_run_context(kwargs) +# with Observer( +# span_type="tool", +# 
func_name=original_func.__name__, +# metrics=metrics, +# metric_collection=metric_collection, +# function_kwargs={"args": sanitized_args, **sanitized_kwargs}, +# ) as observer: +# result = await original_func(*args, **kwargs) +# observer.result = result + +# return result + +# return async_wrapper +# else: + +# @functools.wraps(original_func) +# def sync_wrapper(*args, **kwargs): +# sanitized_args = sanitize_run_context(args) +# sanitized_kwargs = sanitize_run_context(kwargs) +# with Observer( +# span_type="tool", +# func_name=original_func.__name__, +# metrics=metrics, +# metric_collection=metric_collection, +# function_kwargs={"args": sanitized_args, **sanitized_kwargs}, +# ) as observer: +# result = original_func(*args, **kwargs) +# observer.result = result + +# return result + +# return sync_wrapper + + +# def update_trace_context( +# trace_name: Optional[str] = None, +# trace_tags: Optional[List[str]] = None, +# trace_metadata: Optional[dict] = None, +# trace_thread_id: Optional[str] = None, +# trace_user_id: Optional[str] = None, +# trace_metric_collection: Optional[str] = None, +# trace_metrics: Optional[List[BaseMetric]] = None, +# trace_input: Optional[Any] = None, +# trace_output: Optional[Any] = None, +# ): + +# current_trace = current_trace_context.get() + +# if trace_name: +# current_trace.name = trace_name +# if trace_tags: +# current_trace.tags = trace_tags +# if trace_metadata: +# current_trace.metadata = trace_metadata +# if trace_thread_id: +# current_trace.thread_id = trace_thread_id +# if trace_user_id: +# current_trace.user_id = trace_user_id +# if trace_metric_collection: +# current_trace.metric_collection = trace_metric_collection +# if trace_metrics: +# current_trace.metrics = trace_metrics +# if trace_input: +# current_trace.input = trace_input +# if trace_output: +# current_trace.output = trace_output + + + +# def set_llm_span_attributes( +# llm_span: LlmSpan, +# requests: List[ModelRequest], +# result: ModelResponse, +# llm_prompt: Optional[Prompt] = None, +# ): +# llm_span.prompt = llm_prompt + +# input = [] +# for request in requests: +# for part in request.parts: +# if isinstance(part, SystemPromptPart): +# input.append({"role": "System", "content": part.content}) +# elif isinstance(part, UserPromptPart): +# input.append({"role": "User", "content": part.content}) +# elif isinstance(part, ToolCallPart): +# input.append( +# { +# "role": "Tool Call", +# "name": part.tool_name, +# "content": part.args_as_json_str(), +# } +# ) +# elif isinstance(part, ToolReturnPart): +# input.append( +# { +# "role": "Tool Return", +# "name": part.tool_name, +# "content": part.model_response_str(), +# } +# ) +# llm_span.input = input + +# content = "" +# tool_calls = [] +# for part in result.parts: +# if isinstance(part, TextPart): +# content += part.content + "\n" +# elif isinstance(part, ToolCallPart): +# tool_calls.append( +# LlmToolCall(name=part.tool_name, args=part.args_as_dict()) +# ) +# llm_span.output = LlmOutput( +# role="Assistant", content=content, tool_calls=tool_calls +# ) +# llm_span.tools_called = extract_tools_called_from_llm_response(result.parts) + + +# def set_agent_span_attributes(agent_span: AgentSpan, result: AgentRunResult): +# agent_span.tools_called = extract_tools_called(result) diff --git a/deepeval/integrations/pydantic_ai/utils.py b/deepeval/integrations/pydantic_ai/utils.py index 4ff1bc4bc7..d7eedfd398 100644 --- a/deepeval/integrations/pydantic_ai/utils.py +++ b/deepeval/integrations/pydantic_ai/utils.py @@ -1,9 +1,20 @@ -from typing import List 
-from pydantic_ai.messages import ModelResponsePart +from time import perf_counter +from contextlib import asynccontextmanager +import inspect +import functools +from typing import Any, Callable, List, Optional + +from pydantic_ai.models import Model from pydantic_ai.agent import AgentRunResult from pydantic_ai._run_context import RunContext -from deepeval.test_case.llm_test_case import ToolCall +from pydantic_ai.messages import ModelRequest, ModelResponse, ModelResponsePart, SystemPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserPromptPart +from deepeval.prompt import Prompt +from deepeval.tracing.tracing import Observer +from deepeval.metrics.base_metric import BaseMetric +from deepeval.test_case.llm_test_case import ToolCall +from deepeval.tracing.context import current_trace_context, current_span_context +from deepeval.tracing.types import AgentSpan, LlmOutput, LlmSpan, LlmToolCall # llm tools called def extract_tools_called_from_llm_response( @@ -29,7 +40,6 @@ def extract_tools_called_from_llm_response( return tool_calls - # TODO: llm tools called (reposne is present next message) def extract_tools_called(result: AgentRunResult) -> List[ToolCall]: tool_calls = [] @@ -65,7 +75,6 @@ def extract_tools_called(result: AgentRunResult) -> List[ToolCall]: return tool_calls - def sanitize_run_context(value): """ Recursively replace pydantic-ai RunContext instances with ''. @@ -84,3 +93,216 @@ def sanitize_run_context(value): return {sanitize_run_context(v) for v in value} return value + +def patch_llm_model( + model: Model, + llm_metric_collection: Optional[str] = None, + llm_metrics: Optional[List[BaseMetric]] = None, + llm_prompt: Optional[Prompt] = None, +): + original_func = model.request + sig = inspect.signature(original_func) + + try: + model_name = model.model_name + except Exception: + model_name = "unknown" + + @functools.wraps(original_func) + async def wrapper(*args, **kwargs): + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + request = bound.arguments.get("messages", []) + + with Observer( + span_type="llm", + func_name="LLM", + observe_kwargs={"model": model_name}, + metrics=llm_metrics, + metric_collection=llm_metric_collection, + ) as observer: + result = await original_func(*args, **kwargs) + observer.update_span_properties = ( + lambda llm_span: set_llm_span_attributes( + llm_span, request, result, llm_prompt + ) + ) + observer.result = result + return result + + model.request = wrapper + + stream_original_func = model.request_stream + stream_sig = inspect.signature(stream_original_func) + + @asynccontextmanager + async def stream_wrapper(*args, **kwargs): + bound = stream_sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + request = bound.arguments.get("messages", []) + + with Observer( + span_type="llm", + func_name="LLM", + observe_kwargs={"model": model_name}, + metrics=llm_metrics, + metric_collection=llm_metric_collection, + ) as observer: + llm_span: LlmSpan = current_span_context.get() + async with stream_original_func(*args, **kwargs) as streamed_response: + try: + yield streamed_response + if not llm_span.token_intervals: + llm_span.token_intervals = {perf_counter(): "NA"} + else: + llm_span.token_intervals[perf_counter()] = "NA" + finally: + try: + result = streamed_response.get() + observer.update_span_properties = ( + lambda llm_span: set_llm_span_attributes( + llm_span, request, result, llm_prompt + ) + ) + observer.result = result + except Exception: + pass + + model.request_stream = stream_wrapper + +def 
create_patched_tool( + func: Callable, + metrics: Optional[List[BaseMetric]] = None, + metric_collection: Optional[str] = None, +): + import asyncio + + original_func = func + + is_async = asyncio.iscoroutinefunction(original_func) + + if is_async: + + @functools.wraps(original_func) + async def async_wrapper(*args, **kwargs): + sanitized_args = sanitize_run_context(args) + sanitized_kwargs = sanitize_run_context(kwargs) + with Observer( + span_type="tool", + func_name=original_func.__name__, + metrics=metrics, + metric_collection=metric_collection, + function_kwargs={"args": sanitized_args, **sanitized_kwargs}, + ) as observer: + result = await original_func(*args, **kwargs) + observer.result = result + + return result + + return async_wrapper + else: + + @functools.wraps(original_func) + def sync_wrapper(*args, **kwargs): + sanitized_args = sanitize_run_context(args) + sanitized_kwargs = sanitize_run_context(kwargs) + with Observer( + span_type="tool", + func_name=original_func.__name__, + metrics=metrics, + metric_collection=metric_collection, + function_kwargs={"args": sanitized_args, **sanitized_kwargs}, + ) as observer: + result = original_func(*args, **kwargs) + observer.result = result + + return result + + return sync_wrapper + + +def update_trace_context( + trace_name: Optional[str] = None, + trace_tags: Optional[List[str]] = None, + trace_metadata: Optional[dict] = None, + trace_thread_id: Optional[str] = None, + trace_user_id: Optional[str] = None, + trace_metric_collection: Optional[str] = None, + trace_metrics: Optional[List[BaseMetric]] = None, + trace_input: Optional[Any] = None, + trace_output: Optional[Any] = None, +): + + current_trace = current_trace_context.get() + + if trace_name: + current_trace.name = trace_name + if trace_tags: + current_trace.tags = trace_tags + if trace_metadata: + current_trace.metadata = trace_metadata + if trace_thread_id: + current_trace.thread_id = trace_thread_id + if trace_user_id: + current_trace.user_id = trace_user_id + if trace_metric_collection: + current_trace.metric_collection = trace_metric_collection + if trace_metrics: + current_trace.metrics = trace_metrics + if trace_input: + current_trace.input = trace_input + if trace_output: + current_trace.output = trace_output + + + +def set_llm_span_attributes( + llm_span: LlmSpan, + requests: List[ModelRequest], + result: ModelResponse, + llm_prompt: Optional[Prompt] = None, +): + llm_span.prompt = llm_prompt + + input = [] + for request in requests: + for part in request.parts: + if isinstance(part, SystemPromptPart): + input.append({"role": "System", "content": part.content}) + elif isinstance(part, UserPromptPart): + input.append({"role": "User", "content": part.content}) + elif isinstance(part, ToolCallPart): + input.append( + { + "role": "Tool Call", + "name": part.tool_name, + "content": part.args_as_json_str(), + } + ) + elif isinstance(part, ToolReturnPart): + input.append( + { + "role": "Tool Return", + "name": part.tool_name, + "content": part.model_response_str(), + } + ) + llm_span.input = input + + content = "" + tool_calls = [] + for part in result.parts: + if isinstance(part, TextPart): + content += part.content + "\n" + elif isinstance(part, ToolCallPart): + tool_calls.append( + LlmToolCall(name=part.tool_name, args=part.args_as_dict()) + ) + llm_span.output = LlmOutput( + role="Assistant", content=content, tool_calls=tool_calls + ) + llm_span.tools_called = extract_tools_called_from_llm_response(result.parts) + + +def set_agent_span_attributes(agent_span: AgentSpan, 
result: AgentRunResult): + agent_span.tools_called = extract_tools_called(result) \ No newline at end of file diff --git a/tests/test_integrations/test_pydanticai/pydantic_all_tests.py b/tests/test_integrations/test_pydanticai/pydantic_all_tests.py new file mode 100644 index 0000000000..045f59ba08 --- /dev/null +++ b/tests/test_integrations/test_pydanticai/pydantic_all_tests.py @@ -0,0 +1,43 @@ +from deepeval.prompt import Prompt +from deepeval.integrations.pydantic_ai import Agent +from deepeval.tracing import observe +import asyncio + +@observe(type="tool", metric_collection="test_collection_1") +def get_weather(city: str) -> str: + """Gets the weather for a given city.""" + return f"I don't know the weather for {city}." + +prompt = Prompt(alias="asd") +prompt.pull(version="00.00.01") + +agent = Agent( + "openai:gpt-4o-mini", + tools=[get_weather], + system_prompt="You are a helpful weather agent.", + trace_name="test_name_1", + trace_tags=["test_tag_1"], + trace_metadata={"test_metadata_1": "test_metadata_1"}, + trace_thread_id="test_thread_id_1", + trace_user_id="test_user_id_1", + trace_metric_collection="test_collection_1", + llm_metric_collection="test_collection_1", + llm_prompt=prompt, + agent_metric_collection="test_collection_1", +) + +async def execute_agent_stream(): + async with agent.run_stream("What is the weather in London?", name="test_name_2") as result: + async for chunk in result.stream_text(delta=True): + print(chunk, end="", flush=True) + final = await result.get_output() + print("\n\nFinal:", final) + +async def execute_agent_run(): + result = await agent.run("What is the weather in London?", name="test_name_4") + print(result.output) + +def execute_all(): + asyncio.run(execute_agent_stream()) + agent.run_sync("What is the weather in London?", name="test_name_3") + asyncio.run(execute_agent_run()) diff --git a/tests/test_integrations/test_pydanticai/pydantic_multi_agents.py b/tests/test_integrations/test_pydanticai/pydantic_multi_agents.py new file mode 100644 index 0000000000..45d00ea372 --- /dev/null +++ b/tests/test_integrations/test_pydanticai/pydantic_multi_agents.py @@ -0,0 +1,211 @@ +# from pydantic_ai import RunContext +# import asyncio + +# from deepeval.integrations.pydantic_ai import Agent + +# joke_selection_agent = Agent( +# 'openai:gpt-4o', +# system_prompt=( +# 'Use the `joke_factory` to generate some jokes, then choose the best. ' +# 'You must return just a single joke.' +# ), +# trace_name="joke_selection_agent", +# ) +# joke_generation_agent = Agent( +# 'openai:gpt-4o', output_type=list[str], +# ) + + +# @joke_selection_agent.tool +# async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: +# r = await joke_generation_agent.run( +# f'Please generate {count} jokes.', +# usage=ctx.usage, +# ) +# return r.output + +# async def execute_agent(): +# result = await joke_selection_agent.run('Tell me a joke.', name="joke_selection_agent_2") +# print(result.output) + +# asyncio.run(execute_agent()) +# result = joke_selection_agent.run_sync('Tell me a joke.') +# print(result.output) +#> Did you hear about the toothpaste scandal? They called it Colgate. 
+ + +######################################################## + +# from dataclasses import dataclass +# import asyncio +# import httpx + +# from pydantic_ai import RunContext +# from deepeval.integrations.pydantic_ai import Agent + + +# @dataclass +# class ClientAndKey: +# http_client: httpx.AsyncClient +# api_key: str + + +# joke_selection_agent = Agent( +# 'openai:gpt-4o', +# deps_type=ClientAndKey, +# system_prompt=( +# 'Use the `joke_factory` tool to generate some jokes on the given subject, ' +# 'then choose the best. You must return just a single joke.' +# ), +# ) +# joke_generation_agent = Agent( +# 'openai:gpt-4o', +# deps_type=ClientAndKey, +# output_type=list[str], +# system_prompt=( +# 'Use the "get_jokes" tool to get some jokes on the given subject, ' +# 'then extract each joke into a list.' +# ), +# ) + + +# @joke_selection_agent.tool +# async def joke_factory(ctx: RunContext[ClientAndKey], count: int) -> list[str]: +# r = await joke_generation_agent.run( +# f'Please generate {count} jokes.', +# deps=ctx.deps, +# usage=ctx.usage, +# ) +# return r.output + + +# @joke_generation_agent.tool +# async def get_jokes(ctx: RunContext[ClientAndKey], count: int) -> str: +# response = await ctx.deps.http_client.get( +# 'https://example.com', +# params={'count': count}, +# headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, +# ) +# response.raise_for_status() +# return response.text + + +# async def main(): +# async with httpx.AsyncClient() as client: +# deps = ClientAndKey(client, 'foobar') +# result = await joke_selection_agent.run('Tell me a joke.', deps=deps) +# print(result.output) +# #> Did you hear about the toothpaste scandal? They called it Colgate. +# # print(result.usage()) +# #> RunUsage(input_tokens=309, output_tokens=32, requests=4, tool_calls=2) + +# asyncio.run(main()) + + + +from typing import Literal + +from pydantic import BaseModel, Field +from rich.prompt import Prompt + +from pydantic_ai import RunContext +from deepeval.integrations.pydantic_ai import Agent, instrument_pydantic_ai +from pydantic_ai.messages import ModelMessage + +instrument_pydantic_ai() + + +class FlightDetails(BaseModel): + flight_number: str + + +class Failed(BaseModel): + """Unable to find a satisfactory choice.""" + + +flight_search_agent = Agent[None, FlightDetails | Failed]( + 'openai:gpt-4o', + name="flight_search_agent", + output_type=FlightDetails | Failed, # type: ignore + system_prompt=( + 'Use the "flight_search" tool to find a flight ' + 'from the given origin to the given destination.' + ), +) + + +@flight_search_agent.tool +async def flight_search( + ctx: RunContext[None], origin: str, destination: str +) -> FlightDetails | None: + # in reality, this would call a flight search API or + # use a browser to scrape a flight search website + return FlightDetails(flight_number='AK456') + + + +async def find_flight() -> FlightDetails | None: + message_history: list[ModelMessage] | None = None + for _ in range(3): + prompt = Prompt.ask( + 'Where would you like to fly from and to?', + ) + result = await flight_search_agent.run( + prompt, + message_history=message_history, + ) + if isinstance(result.output, FlightDetails): + return result.output + else: + message_history = result.all_messages( + output_tool_return_content='Please try again.' 
+ ) + + +class SeatPreference(BaseModel): + row: int = Field(ge=1, le=30) + seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] + + +# This agent is responsible for extracting the user's seat selection +seat_preference_agent = Agent[None, SeatPreference | Failed]( + 'openai:gpt-4o', + name="seat_preference_agent", + output_type=SeatPreference | Failed, # type: ignore + system_prompt=( + "Extract the user's seat preference. " + 'Seats A and F are window seats. ' + 'Row 1 is the front row and has extra leg room. ' + 'Rows 14, and 20 also have extra leg room. ' + ), +) + + +async def find_seat() -> SeatPreference: + message_history: list[ModelMessage] | None = None + while True: + answer = Prompt.ask('What seat would you like?') + + result = await seat_preference_agent.run( + answer, + message_history=message_history, + ) + if isinstance(result.output, SeatPreference): + return result.output + else: + print('Could not understand seat preference. Please try again.') + message_history = result.all_messages() + + +async def main(): + + opt_flight_details = await find_flight() + if opt_flight_details is not None: + print(f'Flight found: {opt_flight_details.flight_number}') + #> Flight found: AK456 + seat_preference = await find_seat() + print(f'Seat preference: {seat_preference}') + #> Seat preference: row=1 seat='A' + +# import asyncio +# asyncio.run(main()) \ No newline at end of file diff --git a/tests/test_integrations/test_pydanticai/pydantic_run_sync.py b/tests/test_integrations/test_pydanticai/pydantic_run_sync.py deleted file mode 100644 index e99b490a94..0000000000 --- a/tests/test_integrations/test_pydanticai/pydantic_run_sync.py +++ /dev/null @@ -1,21 +0,0 @@ -from pydantic_ai import Agent -from deepeval.tracing import observe -from deepeval.integrations.pydantic_ai import instrument_pydantic_ai - -instrument_pydantic_ai() - - -@observe(type="tool", metric_collection="test_collection_1") -def get_weather(city: str) -> str: - """Gets the weather for a given city.""" - return f"I don't know the weather for {city}." - - -agent = Agent( - "openai:gpt-4o-mini", - tools=[get_weather], - system_prompt="You are a helpful weather agent.", -) -result = agent.run_sync( - "What is the weather in London?", -) diff --git a/tests/test_integrations/test_pydanticai/pydanticai_app.json b/tests/test_integrations/test_pydanticai/pydanticai_app.json index 7f0c3e948b..2bf50ba2ee 100644 --- a/tests/test_integrations/test_pydanticai/pydanticai_app.json +++ b/tests/test_integrations/test_pydanticai/pydanticai_app.json @@ -3,14 +3,14 @@ { "agentHandoffs": [], "availableTools": [], - "endTime": "2025-09-16T15:17:49.883Z", + "endTime": "2025-09-19T20:23:11.180Z", "input": { "input": "What's the weather in Paris?" 
}, "metricCollection": "test_collection_1", - "name": "Agent", + "name": "", "output": "", - "startTime": "2025-09-16T15:17:40.291Z", + "startTime": "2025-09-19T20:23:06.368Z", "status": "SUCCESS", "toolsCalled": [ { @@ -32,12 +32,12 @@ } ], "baseSpans": [], - "endTime": "2025-09-16T15:17:49.883Z", + "endTime": "2025-09-19T20:23:11.180Z", "environment": "development", "input": "What's the weather in Paris?", "llmSpans": [ { - "endTime": "2025-09-16T15:17:49.882Z", + "endTime": "2025-09-19T20:23:11.179Z", "input": [ { "content": "What's the weather in Paris?", @@ -77,14 +77,14 @@ "alias": "asd", "version": "00.00.01" }, - "startTime": "2025-09-16T15:17:48.033Z", + "startTime": "2025-09-19T20:23:10.351Z", "status": "SUCCESS", "toolsCalled": [], "type": "llm", "uuid": "" }, { - "endTime": "2025-09-16T15:17:47.876Z", + "endTime": "2025-09-19T20:23:09.504Z", "input": [ { "content": "What's the weather in Paris?", @@ -122,7 +122,7 @@ "alias": "asd", "version": "00.00.01" }, - "startTime": "2025-09-16T15:17:44.548Z", + "startTime": "2025-09-19T20:23:08.409Z", "status": "SUCCESS", "toolsCalled": [ { @@ -137,7 +137,7 @@ "uuid": "" }, { - "endTime": "2025-09-16T15:17:44.383Z", + "endTime": "2025-09-19T20:23:07.573Z", "input": [ { "content": "What's the weather in Paris?", @@ -164,7 +164,7 @@ "alias": "asd", "version": "00.00.01" }, - "startTime": "2025-09-16T15:17:40.292Z", + "startTime": "2025-09-19T20:23:06.369Z", "status": "SUCCESS", "toolsCalled": [ { @@ -182,10 +182,10 @@ "test_metadata_1": "test_metadata_1" }, "metricCollection": "test_collection_1", - "name": "test_trace_1", + "name": "test_trace_2", "output": "", "retrieverSpans": [], - "startTime": "2025-09-16T15:17:40.291Z", + "startTime": "2025-09-19T20:23:06.368Z", "status": "SUCCESS", "tags": [ "test_tag_1" @@ -193,7 +193,7 @@ "threadId": "test_thread_id_1", "toolSpans": [ { - "endTime": "2025-09-16T15:17:48.033Z", + "endTime": "2025-09-19T20:23:10.350Z", "input": { "args": [ "" @@ -208,13 +208,13 @@ "temperature": "" }, "parentUuid": "", - "startTime": "2025-09-16T15:17:47.876Z", + "startTime": "2025-09-19T20:23:09.504Z", "status": "SUCCESS", "type": "tool", "uuid": "" }, { - "endTime": "2025-09-16T15:17:44.547Z", + "endTime": "2025-09-19T20:23:08.408Z", "input": { "args": [ "" @@ -228,12 +228,12 @@ "lng": "" }, "parentUuid": "", - "startTime": "2025-09-16T15:17:44.384Z", + "startTime": "2025-09-19T20:23:07.574Z", "status": "SUCCESS", "type": "tool", "uuid": "" } ], - "userId": "test_user_id_1", + "userId": "test_user_id_2", "uuid": "" } \ No newline at end of file diff --git a/tests/test_integrations/test_pydanticai/pydanticai_app.py b/tests/test_integrations/test_pydanticai/pydanticai_app.py index 204ac064d9..0b2d9b9808 100644 --- a/tests/test_integrations/test_pydanticai/pydanticai_app.py +++ b/tests/test_integrations/test_pydanticai/pydanticai_app.py @@ -6,16 +6,11 @@ from httpx import AsyncClient from pydantic import BaseModel from pydantic_ai import RunContext -from pydantic_ai import Agent +from deepeval.integrations.pydantic_ai import Agent from deepeval.prompt import Prompt -from deepeval.integrations.pydantic_ai import instrument_pydantic_ai - -instrument_pydantic_ai() - prompt = Prompt(alias="asd") prompt.pull(version="00.00.01") - @dataclass class Deps: client: AsyncClient @@ -29,6 +24,13 @@ class Deps: instructions="Be concise, reply with one sentence.", deps_type=Deps, retries=2, + + trace_name="test_trace_1", + trace_tags=["test_tag_1"], + trace_metadata={"test_metadata_1": "test_metadata_1"}, + 
trace_thread_id="test_thread_id_1", + trace_user_id="test_user_id_1", + trace_metric_collection="test_collection_1", ) @@ -79,12 +81,12 @@ async def run_agent(input_query: str): result = await weather_agent.run( input_query, deps=deps, - metric_collection="test_collection_1", - name="test_trace_1", - tags=["test_tag_1"], - metadata={"test_metadata_1": "test_metadata_1"}, - thread_id="test_thread_id_1", - user_id="test_user_id_1", + # metric_collection="test_collection_1", + name="test_trace_2", + # tags=["test_tag_1"], + # metadata={"test_metadata_1": "test_metadata_1"}, + # thread_id="test_thread_id_1", + user_id="test_user_id_2", ) return result.output
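

Reviewer note (not part of the patch): a minimal usage sketch of the new entry point, mirroring tests/test_integrations/test_pydanticai/pydantic_all_tests.py. It only uses parameters introduced in this diff; the model string and the collection/trace names are placeholders, and the metric collection is assumed to exist on Confident AI.

    from pydantic_ai import RunContext
    from deepeval.integrations.pydantic_ai import Agent

    # DeepEvalPydanticAIAgent consumes the trace_*/llm_*/agent_* kwargs and
    # forwards everything else to pydantic_ai.Agent.
    agent = Agent(
        "openai:gpt-4o-mini",  # placeholder model
        system_prompt="You are a helpful weather agent.",
        trace_name="weather_trace",  # placeholder trace name
        agent_metric_collection="test_collection_1",  # assumed collection
    )

    # Tools registered through the overridden decorator get their own "tool" spans.
    @agent.tool(metric_collection="test_collection_1")
    async def get_weather(ctx: RunContext[None], city: str) -> str:
        """Gets the weather for a given city."""
        return f"I don't know the weather for {city}."

    # Per-call kwargs such as name override the constructor defaults
    # (see update_trace_context in utils.py).
    result = agent.run_sync("What is the weather in London?", name="weather_run")
    print(result.output)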