1+ import inspect
2+ from typing import Optional , List , Generic , TypeVar
3+ from contextvars import ContextVar
4+ from contextlib import asynccontextmanager
5+
6+ from deepeval .prompt import Prompt
7+ from deepeval .tracing .types import AgentSpan
8+ from deepeval .tracing .tracing import Observer
9+ from deepeval .metrics .base_metric import BaseMetric
10+ from deepeval .tracing .context import current_span_context
11+ from deepeval .integrations .pydantic_ai .utils import extract_tools_called
12+
13+ try :
14+ from pydantic_ai .agent import Agent
15+ from pydantic_ai .tools import AgentDepsT
16+ from pydantic_ai .output import OutputDataT
17+ from deepeval .integrations .pydantic_ai .utils import create_patched_tool , update_trace_context , patch_llm_model
18+ is_pydantic_ai_installed = True
19+ except :
20+ is_pydantic_ai_installed = False
21+
22+ def pydantic_ai_installed ():
23+ if not is_pydantic_ai_installed :
24+ raise ImportError (
25+ "Pydantic AI is not installed. Please install it with `pip install pydantic-ai`."
26+ )
27+
28+ _IS_RUN_SYNC = ContextVar ("deepeval_is_run_sync" , default = False )
29+
30+ class DeepEvalPydanticAIAgent (Agent [AgentDepsT , OutputDataT ], Generic [AgentDepsT , OutputDataT ]):
31+
32+ trace_name : Optional [str ] = None
33+ trace_tags : Optional [List [str ]] = None
34+ trace_metadata : Optional [dict ] = None
35+ trace_thread_id : Optional [str ] = None
36+ trace_user_id : Optional [str ] = None
37+ trace_metric_collection : Optional [str ] = None
38+ trace_metrics : Optional [List [BaseMetric ]] = None
39+
40+ llm_prompt : Optional [Prompt ] = None
41+ llm_metrics : Optional [List [BaseMetric ]] = None
42+ llm_metric_collection : Optional [str ] = None
43+
44+ agent_metrics : Optional [List [BaseMetric ]] = None
45+ agent_metric_collection : Optional [str ] = None
46+
47+ def __init__ (
48+ self ,
49+ * args ,
50+ trace_name : Optional [str ] = None ,
51+ trace_tags : Optional [List [str ]] = None ,
52+ trace_metadata : Optional [dict ] = None ,
53+ trace_thread_id : Optional [str ] = None ,
54+ trace_user_id : Optional [str ] = None ,
55+ trace_metric_collection : Optional [str ] = None ,
56+ trace_metrics : Optional [List [BaseMetric ]] = None ,
57+ llm_metric_collection : Optional [str ] = None ,
58+ llm_metrics : Optional [List [BaseMetric ]] = None ,
59+ llm_prompt : Optional [Prompt ] = None ,
60+ agent_metric_collection : Optional [str ] = None ,
61+ agent_metrics : Optional [List [BaseMetric ]] = None ,
62+ ** kwargs
63+ ):
64+ pydantic_ai_installed ()
65+
66+ self .trace_name = trace_name
67+ self .trace_tags = trace_tags
68+ self .trace_metadata = trace_metadata
69+ self .trace_thread_id = trace_thread_id
70+ self .trace_user_id = trace_user_id
71+ self .trace_metric_collection = trace_metric_collection
72+ self .trace_metrics = trace_metrics
73+
74+ self .llm_metric_collection = llm_metric_collection
75+ self .llm_metrics = llm_metrics
76+ self .llm_prompt = llm_prompt
77+
78+ self .agent_metric_collection = agent_metric_collection
79+ self .agent_metrics = agent_metrics
80+
81+ super ().__init__ (* args , ** kwargs )
82+
83+ patch_llm_model (self ._model , llm_metric_collection , llm_metrics , llm_prompt ) #TODO: Add dual patch guards
84+
85+
86+ async def run (
87+ self ,
88+ * args ,
89+
90+ name : Optional [str ] = None ,
91+ tags : Optional [List [str ]] = None ,
92+ user_id : Optional [str ] = None ,
93+ metadata : Optional [dict ] = None ,
94+ thread_id : Optional [str ] = None ,
95+ metrics : Optional [List [BaseMetric ]] = None ,
96+ metric_collection : Optional [str ] = None ,
97+
98+ ** kwargs
99+ ):
100+ sig = inspect .signature (super ().run )
101+ bound = sig .bind_partial (* args , ** kwargs )
102+ bound .apply_defaults ()
103+ input = bound .arguments .get ("user_prompt" , None )
104+
105+ agent_name = super ().name if super ().name is not None else "Agent"
106+
107+ with Observer (
108+ span_type = "agent" if not _IS_RUN_SYNC .get () else "custom" ,
109+ func_name = agent_name if not _IS_RUN_SYNC .get () else "run" ,
110+ function_kwargs = {"input" : input },
111+ metrics = self .agent_metrics if not _IS_RUN_SYNC .get () else None ,
112+ metric_collection = self .agent_metric_collection if not _IS_RUN_SYNC .get () else None ,
113+ ) as observer :
114+ result = await super ().run (* args , ** kwargs )
115+ observer .result = result .output
116+ update_trace_context (
117+
118+ trace_name = name if name is not None else self .trace_name ,
119+ trace_tags = tags if tags is not None else self .trace_tags ,
120+ trace_metadata = metadata if metadata is not None else self .trace_metadata ,
121+ trace_thread_id = thread_id if thread_id is not None else self .trace_thread_id ,
122+ trace_user_id = user_id if user_id is not None else self .trace_user_id ,
123+ trace_metric_collection = metric_collection if metric_collection is not None else self .trace_metric_collection ,
124+ trace_metrics = metrics if metrics is not None else self .trace_metrics ,
125+
126+ trace_input = input ,
127+ trace_output = result .output ,
128+ )
129+
130+ agent_span : AgentSpan = current_span_context .get ()
131+ try :
132+ agent_span .tools_called = extract_tools_called (result )
133+ except :
134+ pass
135+ # TODO: available tools
136+ # TODO: agent handoffs
137+
138+ return result
139+
140+ def run_sync (
141+ self ,
142+ * args ,
143+
144+ name : Optional [str ] = None ,
145+ tags : Optional [List [str ]] = None ,
146+ metadata : Optional [dict ] = None ,
147+ thread_id : Optional [str ] = None ,
148+ user_id : Optional [str ] = None ,
149+ metric_collection : Optional [str ] = None ,
150+ metrics : Optional [List [BaseMetric ]] = None ,
151+
152+ ** kwargs
153+ ):
154+ sig = inspect .signature (super ().run_sync )
155+ bound = sig .bind_partial (* args , ** kwargs )
156+ bound .apply_defaults ()
157+ input = bound .arguments .get ("user_prompt" , None )
158+
159+ token = _IS_RUN_SYNC .set (True )
160+
161+ agent_name = super ().name if super ().name is not None else "Agent"
162+
163+ with Observer (
164+ span_type = "agent" ,
165+ func_name = agent_name ,
166+ function_kwargs = {"input" : input },
167+ metrics = self .agent_metrics ,
168+ metric_collection = self .agent_metric_collection ,
169+ ) as observer :
170+ try :
171+ result = super ().run_sync (* args , ** kwargs )
172+ finally :
173+ _IS_RUN_SYNC .reset (token )
174+
175+ observer .result = result .output
176+ update_trace_context (
177+
178+ trace_name = name if name is not None else self .trace_name ,
179+ trace_tags = tags if tags is not None else self .trace_tags ,
180+ trace_metadata = metadata if metadata is not None else self .trace_metadata ,
181+ trace_thread_id = thread_id if thread_id is not None else self .trace_thread_id ,
182+ trace_user_id = user_id if user_id is not None else self .trace_user_id ,
183+ trace_metric_collection = metric_collection if metric_collection is not None else self .trace_metric_collection ,
184+ trace_metrics = metrics if metrics is not None else self .trace_metrics ,
185+
186+ trace_input = input ,
187+ trace_output = result .output ,
188+ )
189+
190+ agent_span : AgentSpan = current_span_context .get ()
191+ try :
192+ agent_span .tools_called = extract_tools_called (result )
193+ except :
194+ pass
195+
196+ # TODO: available tools
197+ # TODO: agent handoffs
198+
199+ return result
200+
201+ @asynccontextmanager
202+ async def run_stream (
203+ self ,
204+ * args ,
205+
206+ name : Optional [str ] = None ,
207+ tags : Optional [List [str ]] = None ,
208+ metadata : Optional [dict ] = None ,
209+ thread_id : Optional [str ] = None ,
210+ user_id : Optional [str ] = None ,
211+ metric_collection : Optional [str ] = None ,
212+ metrics : Optional [List [BaseMetric ]] = None ,
213+
214+ ** kwargs
215+ ):
216+ sig = inspect .signature (super ().run_stream )
217+ super_params = sig .parameters
218+ super_kwargs = {k : v for k , v in kwargs .items () if k in super_params }
219+ bound = sig .bind_partial (* args , ** super_kwargs )
220+ bound .apply_defaults ()
221+ input = bound .arguments .get ("user_prompt" , None )
222+
223+ agent_name = super ().name if super ().name is not None else "Agent"
224+
225+ with Observer (
226+ span_type = "agent" ,
227+ func_name = agent_name ,
228+ function_kwargs = {"input" : input },
229+ metrics = self .agent_metrics ,
230+ metric_collection = self .agent_metric_collection ,
231+ ) as observer :
232+ final_result = None
233+ async with super ().run_stream (* args , ** super_kwargs ) as result :
234+ try :
235+ yield result
236+ finally :
237+ try :
238+ final_result = await result .get_output ()
239+ observer .result = final_result
240+ except Exception :
241+ pass
242+
243+ update_trace_context (
244+
245+ trace_name = name if name is not None else self .trace_name ,
246+ trace_tags = tags if tags is not None else self .trace_tags ,
247+ trace_metadata = metadata if metadata is not None else self .trace_metadata ,
248+ trace_thread_id = thread_id if thread_id is not None else self .trace_thread_id ,
249+ trace_user_id = user_id if user_id is not None else self .trace_user_id ,
250+ trace_metric_collection = metric_collection if metric_collection is not None else self .trace_metric_collection ,
251+ trace_metrics = metrics if metrics is not None else self .trace_metrics ,
252+
253+ trace_input = input ,
254+ trace_output = (final_result if final_result is not None else None ),
255+ )
256+ agent_span : AgentSpan = current_span_context .get ()
257+ try :
258+ if final_result is not None :
259+ agent_span .tools_called = extract_tools_called (final_result )
260+ except :
261+ pass
262+
263+ def tool (
264+ self ,
265+ * args ,
266+ metrics : Optional [List [BaseMetric ]] = None ,
267+ metric_collection : Optional [str ] = None ,
268+ ** kwargs
269+ ):
270+ # Direct decoration: @agent.tool
271+ if args and callable (args [0 ]):
272+ patched_func = create_patched_tool (args [0 ], metrics , metric_collection )
273+ new_args = (patched_func ,) + args [1 :]
274+ return super (DeepEvalPydanticAIAgent , self ).tool (* new_args , ** kwargs )
275+ # Decoration with args: @agent.tool(...)
276+ super_tool = super (DeepEvalPydanticAIAgent , self ).tool
277+ def decorator (func ):
278+ patched_func = create_patched_tool (func , metrics , metric_collection )
279+ return super_tool (* args , ** kwargs )(patched_func )
280+ return decorator
0 commit comments