Skip to content

Commit dc23d34

Browse files
.
1 parent d9df2a2 commit dc23d34

1 file changed

Lines changed: 39 additions & 24 deletions

File tree

  • deepeval/integrations/pydantic_ai

deepeval/integrations/pydantic_ai/agent.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
2-
from typing import Optional, List, Generic, TypeVar
2+
from typing import Optional, List, Generic, TypeVar, AsyncIterator
3+
from typing import Any
34
from contextvars import ContextVar
45
from contextlib import asynccontextmanager
56

@@ -19,7 +20,8 @@
1920
update_trace_context,
2021
patch_llm_model,
2122
)
22-
23+
from pydantic_ai.output import OutputSpec
24+
from pydantic_ai.result import AgentRunResult, StreamedRunResult
2325
is_pydantic_ai_installed = True
2426
except:
2527
is_pydantic_ai_installed = False
@@ -35,10 +37,14 @@ def pydantic_ai_installed():
3537
_IS_RUN_SYNC = ContextVar("deepeval_is_run_sync", default=False)
3638

3739

38-
class DeepEvalPydanticAIAgent(
39-
Agent[AgentDepsT, OutputDataT], Generic[AgentDepsT, OutputDataT]
40-
):
40+
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
41+
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
42+
NoneType = type(None)
4143

44+
class DeepEvalPydanticAIAgent(Generic[AgentDepsT, OutputDataT]):
45+
46+
agent: Agent
47+
4248
trace_name: Optional[str] = None
4349
trace_tags: Optional[List[str]] = None
4450
trace_metadata: Optional[dict] = None
@@ -57,6 +63,8 @@ class DeepEvalPydanticAIAgent(
5763
def __init__(
5864
self,
5965
*args,
66+
output_type: OutputSpec[OutputDataT] = str,
67+
deps_type: type[AgentDepsT] = NoneType,
6068
trace_name: Optional[str] = None,
6169
trace_tags: Optional[List[str]] = None,
6270
trace_metadata: Optional[dict] = None,
@@ -88,10 +96,15 @@ def __init__(
8896
self.agent_metric_collection = agent_metric_collection
8997
self.agent_metrics = agent_metrics
9098

91-
super().__init__(*args, **kwargs)
99+
self.agent = Agent(
100+
*args,
101+
output_type=output_type,
102+
deps_type=deps_type,
103+
**kwargs
104+
)
92105

93106
patch_llm_model(
94-
self._model, llm_metric_collection, llm_metrics, llm_prompt
107+
self.agent.model, llm_metric_collection, llm_metrics, llm_prompt
95108
) # TODO: Add dual patch guards
96109

97110
async def run(
@@ -105,13 +118,13 @@ async def run(
105118
metrics: Optional[List[BaseMetric]] = None,
106119
metric_collection: Optional[str] = None,
107120
**kwargs
108-
):
109-
sig = inspect.signature(super().run)
121+
) -> AgentRunResult[OutputDataT]:
122+
sig = inspect.signature(self.agent.run)
110123
bound = sig.bind_partial(*args, **kwargs)
111124
bound.apply_defaults()
112125
input = bound.arguments.get("user_prompt", None)
113126

114-
agent_name = super().name if super().name is not None else "Agent"
127+
agent_name = self.agent.name if self.agent.name is not None else "Agent"
115128

116129
with Observer(
117130
span_type="agent" if not _IS_RUN_SYNC.get() else "custom",
@@ -122,7 +135,11 @@ async def run(
122135
self.agent_metric_collection if not _IS_RUN_SYNC.get() else None
123136
),
124137
) as observer:
125-
result = await super().run(*args, **kwargs)
138+
result = await self.agent.run(
139+
*args,
140+
**kwargs
141+
)
142+
126143
observer.result = result.output
127144
update_trace_context(
128145
trace_name=name if name is not None else self.trace_name,
@@ -169,15 +186,15 @@ def run_sync(
169186
metric_collection: Optional[str] = None,
170187
metrics: Optional[List[BaseMetric]] = None,
171188
**kwargs
172-
):
173-
sig = inspect.signature(super().run_sync)
189+
) -> AgentRunResult[OutputDataT]:
190+
sig = inspect.signature(self.agent.run_sync)
174191
bound = sig.bind_partial(*args, **kwargs)
175192
bound.apply_defaults()
176193
input = bound.arguments.get("user_prompt", None)
177194

178195
token = _IS_RUN_SYNC.set(True)
179196

180-
agent_name = super().name if super().name is not None else "Agent"
197+
agent_name = self.agent.name if self.agent.name is not None else "Agent"
181198

182199
with Observer(
183200
span_type="agent",
@@ -187,7 +204,7 @@ def run_sync(
187204
metric_collection=self.agent_metric_collection,
188205
) as observer:
189206
try:
190-
result = super().run_sync(*args, **kwargs)
207+
result = self.agent.run_sync(*args, **kwargs)
191208
finally:
192209
_IS_RUN_SYNC.reset(token)
193210

@@ -239,15 +256,15 @@ async def run_stream(
239256
metric_collection: Optional[str] = None,
240257
metrics: Optional[List[BaseMetric]] = None,
241258
**kwargs
242-
):
243-
sig = inspect.signature(super().run_stream)
259+
) -> AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]]:
260+
sig = inspect.signature(self.agent.run_stream)
244261
super_params = sig.parameters
245262
super_kwargs = {k: v for k, v in kwargs.items() if k in super_params}
246263
bound = sig.bind_partial(*args, **super_kwargs)
247264
bound.apply_defaults()
248265
input = bound.arguments.get("user_prompt", None)
249266

250-
agent_name = super().name if super().name is not None else "Agent"
267+
agent_name = self.agent.name if self.agent.name is not None else "Agent"
251268

252269
with Observer(
253270
span_type="agent",
@@ -257,7 +274,7 @@ async def run_stream(
257274
metric_collection=self.agent_metric_collection,
258275
) as observer:
259276
final_result = None
260-
async with super().run_stream(*args, **super_kwargs) as result:
277+
async with self.agent.run_stream(*args, **kwargs) as result:
261278
try:
262279
yield result
263280
finally:
@@ -319,21 +336,19 @@ def tool(
319336
metrics: Optional[List[BaseMetric]] = None,
320337
metric_collection: Optional[str] = None,
321338
**kwargs
322-
):
339+
) -> Any:
323340
# Direct decoration: @agent.tool
324341
if args and callable(args[0]):
325342
patched_func = create_patched_tool(
326343
args[0], metrics, metric_collection
327344
)
328345
new_args = (patched_func,) + args[1:]
329-
return super(DeepEvalPydanticAIAgent, self).tool(
346+
return self.agent.tool(
330347
*new_args, **kwargs
331348
)
332-
# Decoration with args: @agent.tool(...)
333-
super_tool = super(DeepEvalPydanticAIAgent, self).tool
334349

335350
def decorator(func):
336351
patched_func = create_patched_tool(func, metrics, metric_collection)
337-
return super_tool(*args, **kwargs)(patched_func)
352+
return self.agent.tool(*args, **kwargs)(patched_func)
338353

339354
return decorator

0 commit comments

Comments
 (0)