Skip to content

Commit 0ad4b60

Browse files
.
1 parent a7103de commit 0ad4b60

2 files changed

Lines changed: 107 additions & 5 deletions

File tree

deepeval/integrations/pydantic_ai/agent.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
from typing import Optional, List
33
from contextvars import ContextVar
4+
from contextlib import asynccontextmanager
45

56
from deepeval.prompt import Prompt
67
from deepeval.tracing.types import AgentSpan
@@ -96,10 +97,12 @@ async def run(
9697
bound = sig.bind_partial(*args, **kwargs)
9798
bound.apply_defaults()
9899
input = bound.arguments.get("user_prompt", None)
100+
101+
agent_name = super().name if super().name is not None else "Agent"
99102

100103
with Observer(
101104
span_type="agent" if not _IS_RUN_SYNC.get() else "custom",
102-
func_name="Agent" if not _IS_RUN_SYNC.get() else "run",
105+
func_name=agent_name if not _IS_RUN_SYNC.get() else "run",
103106
function_kwargs={"input": input},
104107
metrics=self.agent_metrics if not _IS_RUN_SYNC.get() else None,
105108
metric_collection=self.agent_metric_collection if not _IS_RUN_SYNC.get() else None,
@@ -149,9 +152,11 @@ def run_sync(
149152

150153
token = _IS_RUN_SYNC.set(True)
151154

155+
agent_name = super().name if super().name is not None else "Agent"
156+
152157
with Observer(
153158
span_type="agent",
154-
func_name="Agent",
159+
func_name=agent_name,
155160
function_kwargs={"input": input},
156161
metrics=self.agent_metrics,
157162
metric_collection=self.agent_metric_collection,
@@ -187,7 +192,64 @@ def run_sync(
187192

188193
return result
189194

190-
195+
@asynccontextmanager
196+
async def run_stream(
197+
self,
198+
*args,
199+
name: Optional[str] = None,
200+
tags: Optional[List[str]] = None,
201+
metadata: Optional[dict] = None,
202+
thread_id: Optional[str] = None,
203+
user_id: Optional[str] = None,
204+
metric_collection: Optional[str] = None,
205+
metrics: Optional[List[BaseMetric]] = None,
206+
**kwargs
207+
):
208+
sig = inspect.signature(super().run_stream)
209+
super_params = sig.parameters
210+
super_kwargs = {k: v for k, v in kwargs.items() if k in super_params}
211+
bound = sig.bind_partial(*args, **super_kwargs)
212+
bound.apply_defaults()
213+
input = bound.arguments.get("user_prompt", None)
214+
215+
agent_name = super().name if super().name is not None else "Agent"
216+
217+
with Observer(
218+
span_type="agent",
219+
func_name=agent_name,
220+
function_kwargs={"input": input},
221+
metrics=self.agent_metrics,
222+
metric_collection=self.agent_metric_collection,
223+
) as observer:
224+
final_result = None
225+
async with super().run_stream(*args, **super_kwargs) as result:
226+
try:
227+
yield result
228+
finally:
229+
try:
230+
final_result = await result.get_output()
231+
observer.result = final_result
232+
except Exception:
233+
pass
234+
235+
update_trace_context(
236+
trace_name=name if name is not None else self.trace_name,
237+
trace_tags=tags if tags is not None else self.trace_tags,
238+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
239+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
240+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
241+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
242+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
243+
trace_input=input,
244+
trace_output=(final_result if final_result is not None else None),
245+
)
246+
agent_span: AgentSpan = current_span_context.get()
247+
try:
248+
if final_result is not None:
249+
agent_span.tools_called = extract_tools_called(final_result)
250+
except:
251+
pass
252+
191253
def tool(
192254
self,
193255
*args,

deepeval/integrations/pydantic_ai/utils.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from time import perf_counter
2+
from contextlib import asynccontextmanager
13
import inspect
24
import functools
35
from typing import Any, Callable, List, Optional
@@ -11,7 +13,7 @@
1113
from deepeval.tracing.tracing import Observer
1214
from deepeval.metrics.base_metric import BaseMetric
1315
from deepeval.test_case.llm_test_case import ToolCall
14-
from deepeval.tracing.context import current_trace_context
16+
from deepeval.tracing.context import current_trace_context, current_span_context
1517
from deepeval.tracing.types import AgentSpan, LlmOutput, LlmSpan, LlmToolCall
1618

1719
# llm tools called
@@ -127,9 +129,47 @@ async def wrapper(*args, **kwargs):
127129
)
128130
observer.result = result
129131
return result
130-
132+
131133
model.request = wrapper
132134

135+
stream_original_func = model.request_stream
136+
stream_sig = inspect.signature(stream_original_func)
137+
138+
@asynccontextmanager
139+
async def stream_wrapper(*args, **kwargs):
140+
bound = stream_sig.bind_partial(*args, **kwargs)
141+
bound.apply_defaults()
142+
request = bound.arguments.get("messages", [])
143+
144+
with Observer(
145+
span_type="llm",
146+
func_name="LLM",
147+
observe_kwargs={"model": model_name},
148+
metrics=llm_metrics,
149+
metric_collection=llm_metric_collection,
150+
) as observer:
151+
llm_span: LlmSpan = current_span_context.get()
152+
async with stream_original_func(*args, **kwargs) as streamed_response:
153+
try:
154+
yield streamed_response
155+
print("streamed_response >>>>>")
156+
if not llm_span.token_intervals:
157+
llm_span.token_intervals = {perf_counter(): "NA"}
158+
else:
159+
llm_span.token_intervals[perf_counter()] = "NA"
160+
finally:
161+
try:
162+
result = streamed_response.get()
163+
observer.update_span_properties = (
164+
lambda llm_span: set_llm_span_attributes(
165+
llm_span, request, result, llm_prompt
166+
)
167+
)
168+
observer.result = result
169+
except Exception:
170+
pass
171+
172+
model.request_stream = stream_wrapper
133173

134174
def create_patched_tool(
135175
func: Callable,

0 commit comments

Comments (0)