Skip to content

Commit 354189e

Browse files
authored
Merge pull request #2071 from confident-ai/mayank/pydantic_fixes_asap
pydantic ai fixes
2 parents 2b1998a + 2b1a13e commit 354189e

10 files changed

Lines changed: 1276 additions & 465 deletions

File tree

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .agent import DeepEvalPydanticAIAgent as Agent
12
from .patcher import instrument as instrument_pydantic_ai
3+
from .otel import instrument_pydantic_ai as otel_instrument_pydantic_ai
24

3-
__all__ = ["instrument_pydantic_ai"]
5+
__all__ = ["instrument_pydantic_ai", "Agent", otel_instrument_pydantic_ai]
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import inspect
2+
from typing import Optional, List, Generic, TypeVar
3+
from contextvars import ContextVar
4+
from contextlib import asynccontextmanager
5+
6+
from deepeval.prompt import Prompt
7+
from deepeval.tracing.types import AgentSpan
8+
from deepeval.tracing.tracing import Observer
9+
from deepeval.metrics.base_metric import BaseMetric
10+
from deepeval.tracing.context import current_span_context
11+
from deepeval.integrations.pydantic_ai.utils import extract_tools_called
12+
13+
try:
14+
from pydantic_ai.agent import Agent
15+
from pydantic_ai.tools import AgentDepsT
16+
from pydantic_ai.output import OutputDataT
17+
from deepeval.integrations.pydantic_ai.utils import create_patched_tool, update_trace_context, patch_llm_model
18+
is_pydantic_ai_installed = True
19+
except:
20+
is_pydantic_ai_installed = False
21+
22+
def pydantic_ai_installed():
23+
if not is_pydantic_ai_installed:
24+
raise ImportError(
25+
"Pydantic AI is not installed. Please install it with `pip install pydantic-ai`."
26+
)
27+
28+
_IS_RUN_SYNC = ContextVar("deepeval_is_run_sync", default=False)
29+
30+
class DeepEvalPydanticAIAgent(Agent[AgentDepsT, OutputDataT], Generic[AgentDepsT, OutputDataT]):
31+
32+
trace_name: Optional[str] = None
33+
trace_tags: Optional[List[str]] = None
34+
trace_metadata: Optional[dict] = None
35+
trace_thread_id: Optional[str] = None
36+
trace_user_id: Optional[str] = None
37+
trace_metric_collection: Optional[str] = None
38+
trace_metrics: Optional[List[BaseMetric]] = None
39+
40+
llm_prompt: Optional[Prompt] = None
41+
llm_metrics: Optional[List[BaseMetric]] = None
42+
llm_metric_collection: Optional[str] = None
43+
44+
agent_metrics: Optional[List[BaseMetric]] = None
45+
agent_metric_collection: Optional[str] = None
46+
47+
def __init__(
48+
self,
49+
*args,
50+
trace_name: Optional[str] = None,
51+
trace_tags: Optional[List[str]] = None,
52+
trace_metadata: Optional[dict] = None,
53+
trace_thread_id: Optional[str] = None,
54+
trace_user_id: Optional[str] = None,
55+
trace_metric_collection: Optional[str] = None,
56+
trace_metrics: Optional[List[BaseMetric]] = None,
57+
llm_metric_collection: Optional[str] = None,
58+
llm_metrics: Optional[List[BaseMetric]] = None,
59+
llm_prompt: Optional[Prompt] = None,
60+
agent_metric_collection: Optional[str] = None,
61+
agent_metrics: Optional[List[BaseMetric]] = None,
62+
**kwargs
63+
):
64+
pydantic_ai_installed()
65+
66+
self.trace_name = trace_name
67+
self.trace_tags = trace_tags
68+
self.trace_metadata = trace_metadata
69+
self.trace_thread_id = trace_thread_id
70+
self.trace_user_id = trace_user_id
71+
self.trace_metric_collection = trace_metric_collection
72+
self.trace_metrics = trace_metrics
73+
74+
self.llm_metric_collection = llm_metric_collection
75+
self.llm_metrics = llm_metrics
76+
self.llm_prompt = llm_prompt
77+
78+
self.agent_metric_collection = agent_metric_collection
79+
self.agent_metrics = agent_metrics
80+
81+
super().__init__(*args, **kwargs)
82+
83+
patch_llm_model(self._model, llm_metric_collection, llm_metrics, llm_prompt) #TODO: Add dual patch guards
84+
85+
86+
async def run(
87+
self,
88+
*args,
89+
90+
name: Optional[str] = None,
91+
tags: Optional[List[str]] = None,
92+
user_id: Optional[str] = None,
93+
metadata: Optional[dict] = None,
94+
thread_id: Optional[str] = None,
95+
metrics: Optional[List[BaseMetric]] = None,
96+
metric_collection: Optional[str] = None,
97+
98+
**kwargs
99+
):
100+
sig = inspect.signature(super().run)
101+
bound = sig.bind_partial(*args, **kwargs)
102+
bound.apply_defaults()
103+
input = bound.arguments.get("user_prompt", None)
104+
105+
agent_name = super().name if super().name is not None else "Agent"
106+
107+
with Observer(
108+
span_type="agent" if not _IS_RUN_SYNC.get() else "custom",
109+
func_name=agent_name if not _IS_RUN_SYNC.get() else "run",
110+
function_kwargs={"input": input},
111+
metrics=self.agent_metrics if not _IS_RUN_SYNC.get() else None,
112+
metric_collection=self.agent_metric_collection if not _IS_RUN_SYNC.get() else None,
113+
) as observer:
114+
result = await super().run(*args, **kwargs)
115+
observer.result = result.output
116+
update_trace_context(
117+
118+
trace_name=name if name is not None else self.trace_name,
119+
trace_tags=tags if tags is not None else self.trace_tags,
120+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
121+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
122+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
123+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
124+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
125+
126+
trace_input=input,
127+
trace_output=result.output,
128+
)
129+
130+
agent_span: AgentSpan = current_span_context.get()
131+
try:
132+
agent_span.tools_called = extract_tools_called(result)
133+
except:
134+
pass
135+
# TODO: available tools
136+
# TODO: agent handoffs
137+
138+
return result
139+
140+
def run_sync(
141+
self,
142+
*args,
143+
144+
name: Optional[str] = None,
145+
tags: Optional[List[str]] = None,
146+
metadata: Optional[dict] = None,
147+
thread_id: Optional[str] = None,
148+
user_id: Optional[str] = None,
149+
metric_collection: Optional[str] = None,
150+
metrics: Optional[List[BaseMetric]] = None,
151+
152+
**kwargs
153+
):
154+
sig = inspect.signature(super().run_sync)
155+
bound = sig.bind_partial(*args, **kwargs)
156+
bound.apply_defaults()
157+
input = bound.arguments.get("user_prompt", None)
158+
159+
token = _IS_RUN_SYNC.set(True)
160+
161+
agent_name = super().name if super().name is not None else "Agent"
162+
163+
with Observer(
164+
span_type="agent",
165+
func_name=agent_name,
166+
function_kwargs={"input": input},
167+
metrics=self.agent_metrics,
168+
metric_collection=self.agent_metric_collection,
169+
) as observer:
170+
try:
171+
result = super().run_sync(*args, **kwargs)
172+
finally:
173+
_IS_RUN_SYNC.reset(token)
174+
175+
observer.result = result.output
176+
update_trace_context(
177+
178+
trace_name=name if name is not None else self.trace_name,
179+
trace_tags=tags if tags is not None else self.trace_tags,
180+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
181+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
182+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
183+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
184+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
185+
186+
trace_input=input,
187+
trace_output=result.output,
188+
)
189+
190+
agent_span: AgentSpan = current_span_context.get()
191+
try:
192+
agent_span.tools_called = extract_tools_called(result)
193+
except:
194+
pass
195+
196+
# TODO: available tools
197+
# TODO: agent handoffs
198+
199+
return result
200+
201+
@asynccontextmanager
202+
async def run_stream(
203+
self,
204+
*args,
205+
206+
name: Optional[str] = None,
207+
tags: Optional[List[str]] = None,
208+
metadata: Optional[dict] = None,
209+
thread_id: Optional[str] = None,
210+
user_id: Optional[str] = None,
211+
metric_collection: Optional[str] = None,
212+
metrics: Optional[List[BaseMetric]] = None,
213+
214+
**kwargs
215+
):
216+
sig = inspect.signature(super().run_stream)
217+
super_params = sig.parameters
218+
super_kwargs = {k: v for k, v in kwargs.items() if k in super_params}
219+
bound = sig.bind_partial(*args, **super_kwargs)
220+
bound.apply_defaults()
221+
input = bound.arguments.get("user_prompt", None)
222+
223+
agent_name = super().name if super().name is not None else "Agent"
224+
225+
with Observer(
226+
span_type="agent",
227+
func_name=agent_name,
228+
function_kwargs={"input": input},
229+
metrics=self.agent_metrics,
230+
metric_collection=self.agent_metric_collection,
231+
) as observer:
232+
final_result = None
233+
async with super().run_stream(*args, **super_kwargs) as result:
234+
try:
235+
yield result
236+
finally:
237+
try:
238+
final_result = await result.get_output()
239+
observer.result = final_result
240+
except Exception:
241+
pass
242+
243+
update_trace_context(
244+
245+
trace_name=name if name is not None else self.trace_name,
246+
trace_tags=tags if tags is not None else self.trace_tags,
247+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
248+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
249+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
250+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
251+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
252+
253+
trace_input=input,
254+
trace_output=(final_result if final_result is not None else None),
255+
)
256+
agent_span: AgentSpan = current_span_context.get()
257+
try:
258+
if final_result is not None:
259+
agent_span.tools_called = extract_tools_called(final_result)
260+
except:
261+
pass
262+
263+
def tool(
264+
self,
265+
*args,
266+
metrics: Optional[List[BaseMetric]] = None,
267+
metric_collection: Optional[str] = None,
268+
**kwargs
269+
):
270+
# Direct decoration: @agent.tool
271+
if args and callable(args[0]):
272+
patched_func = create_patched_tool(args[0], metrics, metric_collection)
273+
new_args = (patched_func,) + args[1:]
274+
return super(DeepEvalPydanticAIAgent, self).tool(*new_args, **kwargs)
275+
# Decoration with args: @agent.tool(...)
276+
super_tool = super(DeepEvalPydanticAIAgent, self).tool
277+
def decorator(func):
278+
patched_func = create_patched_tool(func, metrics, metric_collection)
279+
return super_tool(*args, **kwargs)(patched_func)
280+
return decorator

deepeval/integrations/pydantic_ai/otel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def is_opentelemetry_available():
2626

2727
OTLP_ENDPOINT = "https://otel.confident-ai.com/v1/traces"
2828

29-
3029
def instrument_pydantic_ai(api_key: Optional[str] = None):
3130
with capture_tracing_integration("pydantic_ai"):
3231
is_opentelemetry_available()

0 commit comments

Comments
 (0)