11import inspect
2- from typing import Optional , List , Generic , TypeVar
2+ from typing import Optional , List , Generic , TypeVar , AsyncIterator
3+ from typing import Any
34from contextvars import ContextVar
45from contextlib import asynccontextmanager
56
1920 update_trace_context ,
2021 patch_llm_model ,
2122 )
22-
23+ from pydantic_ai .output import OutputSpec
24+ from pydantic_ai .result import AgentRunResult , StreamedRunResult
2325 is_pydantic_ai_installed = True
2426except :
2527 is_pydantic_ai_installed = False
@@ -35,10 +37,14 @@ def pydantic_ai_installed():
3537_IS_RUN_SYNC = ContextVar ("deepeval_is_run_sync" , default = False )
3638
3739
38- class DeepEvalPydanticAIAgent (
39- Agent [ AgentDepsT , OutputDataT ], Generic [ AgentDepsT , OutputDataT ]
40- ):
40+ AgentDepsT = TypeVar ( 'AgentDepsT' , default = None , contravariant = True )
41+ OutputDataT = TypeVar ( ' OutputDataT' , default = str , covariant = True )
42+ NoneType = type ( None )
4143
44+ class DeepEvalPydanticAIAgent (Generic [AgentDepsT , OutputDataT ]):
45+
46+ agent : Agent
47+
4248 trace_name : Optional [str ] = None
4349 trace_tags : Optional [List [str ]] = None
4450 trace_metadata : Optional [dict ] = None
@@ -57,6 +63,8 @@ class DeepEvalPydanticAIAgent(
5763 def __init__ (
5864 self ,
5965 * args ,
66+ output_type : OutputSpec [OutputDataT ] = str ,
67+ deps_type : type [AgentDepsT ] = NoneType ,
6068 trace_name : Optional [str ] = None ,
6169 trace_tags : Optional [List [str ]] = None ,
6270 trace_metadata : Optional [dict ] = None ,
@@ -88,10 +96,15 @@ def __init__(
8896 self .agent_metric_collection = agent_metric_collection
8997 self .agent_metrics = agent_metrics
9098
91- super ().__init__ (* args , ** kwargs )
99+ self .agent = Agent (
100+ * args ,
101+ output_type = output_type ,
102+ deps_type = deps_type ,
103+ ** kwargs
104+ )
92105
93106 patch_llm_model (
94- self ._model , llm_metric_collection , llm_metrics , llm_prompt
107+ self .agent . model , llm_metric_collection , llm_metrics , llm_prompt
95108 ) # TODO: Add dual patch guards
96109
97110 async def run (
@@ -105,13 +118,13 @@ async def run(
105118 metrics : Optional [List [BaseMetric ]] = None ,
106119 metric_collection : Optional [str ] = None ,
107120 ** kwargs
108- ):
109- sig = inspect .signature (super () .run )
121+ ) -> AgentRunResult [ OutputDataT ] :
122+ sig = inspect .signature (self . agent .run )
110123 bound = sig .bind_partial (* args , ** kwargs )
111124 bound .apply_defaults ()
112125 input = bound .arguments .get ("user_prompt" , None )
113126
114- agent_name = super (). name if super () .name is not None else "Agent"
127+ agent_name = self . agent . name if self . agent .name is not None else "Agent"
115128
116129 with Observer (
117130 span_type = "agent" if not _IS_RUN_SYNC .get () else "custom" ,
@@ -122,7 +135,11 @@ async def run(
122135 self .agent_metric_collection if not _IS_RUN_SYNC .get () else None
123136 ),
124137 ) as observer :
125- result = await super ().run (* args , ** kwargs )
138+ result = await self .agent .run (
139+ * args ,
140+ ** kwargs
141+ )
142+
126143 observer .result = result .output
127144 update_trace_context (
128145 trace_name = name if name is not None else self .trace_name ,
@@ -169,15 +186,15 @@ def run_sync(
169186 metric_collection : Optional [str ] = None ,
170187 metrics : Optional [List [BaseMetric ]] = None ,
171188 ** kwargs
172- ):
173- sig = inspect .signature (super () .run_sync )
189+ ) -> AgentRunResult [ OutputDataT ] :
190+ sig = inspect .signature (self . agent .run_sync )
174191 bound = sig .bind_partial (* args , ** kwargs )
175192 bound .apply_defaults ()
176193 input = bound .arguments .get ("user_prompt" , None )
177194
178195 token = _IS_RUN_SYNC .set (True )
179196
180- agent_name = super (). name if super () .name is not None else "Agent"
197+ agent_name = self . agent . name if self . agent .name is not None else "Agent"
181198
182199 with Observer (
183200 span_type = "agent" ,
@@ -187,7 +204,7 @@ def run_sync(
187204 metric_collection = self .agent_metric_collection ,
188205 ) as observer :
189206 try :
190- result = super () .run_sync (* args , ** kwargs )
207+ result = self . agent .run_sync (* args , ** kwargs )
191208 finally :
192209 _IS_RUN_SYNC .reset (token )
193210
@@ -239,15 +256,15 @@ async def run_stream(
239256 metric_collection : Optional [str ] = None ,
240257 metrics : Optional [List [BaseMetric ]] = None ,
241258 ** kwargs
242- ):
243- sig = inspect .signature (super () .run_stream )
259+ ) -> AsyncIterator [ StreamedRunResult [ AgentDepsT , OutputDataT ]] :
260+ sig = inspect .signature (self . agent .run_stream )
244261 super_params = sig .parameters
245262 super_kwargs = {k : v for k , v in kwargs .items () if k in super_params }
246263 bound = sig .bind_partial (* args , ** super_kwargs )
247264 bound .apply_defaults ()
248265 input = bound .arguments .get ("user_prompt" , None )
249266
250- agent_name = super (). name if super () .name is not None else "Agent"
267+ agent_name = self . agent . name if self . agent .name is not None else "Agent"
251268
252269 with Observer (
253270 span_type = "agent" ,
@@ -257,7 +274,7 @@ async def run_stream(
257274 metric_collection = self .agent_metric_collection ,
258275 ) as observer :
259276 final_result = None
260- async with super (). run_stream (* args , ** super_kwargs ) as result :
277+ async with self . agent . run_stream (* args , ** kwargs ) as result :
261278 try :
262279 yield result
263280 finally :
@@ -319,21 +336,19 @@ def tool(
319336 metrics : Optional [List [BaseMetric ]] = None ,
320337 metric_collection : Optional [str ] = None ,
321338 ** kwargs
322- ):
339+ ) -> Any :
323340 # Direct decoration: @agent.tool
324341 if args and callable (args [0 ]):
325342 patched_func = create_patched_tool (
326343 args [0 ], metrics , metric_collection
327344 )
328345 new_args = (patched_func ,) + args [1 :]
329- return super ( DeepEvalPydanticAIAgent , self ) .tool (
346+ return self . agent .tool (
330347 * new_args , ** kwargs
331348 )
332- # Decoration with args: @agent.tool(...)
333- super_tool = super (DeepEvalPydanticAIAgent , self ).tool
334349
335350 def decorator (func ):
336351 patched_func = create_patched_tool (func , metrics , metric_collection )
337- return super_tool (* args , ** kwargs )(patched_func )
352+ return self . agent . tool (* args , ** kwargs )(patched_func )
338353
339354 return decorator
0 commit comments