Skip to content

Commit f1f5190

Browse files
.
1 parent c41e89b commit f1f5190

3 files changed

Lines changed: 61 additions & 49 deletions

File tree

deepeval/integrations/pydantic_ai/agent.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ def __init__(
8484
async def run(
8585
self,
8686
*args,
87-
trace_name: Optional[str] = None,
88-
trace_tags: Optional[List[str]] = None,
89-
trace_user_id: Optional[str] = None,
90-
trace_metadata: Optional[dict] = None,
91-
trace_thread_id: Optional[str] = None,
92-
trace_metrics: Optional[List[BaseMetric]] = None,
93-
trace_metric_collection: Optional[str] = None,
87+
88+
name: Optional[str] = None,
89+
tags: Optional[List[str]] = None,
90+
user_id: Optional[str] = None,
91+
metadata: Optional[dict] = None,
92+
thread_id: Optional[str] = None,
93+
metrics: Optional[List[BaseMetric]] = None,
94+
metric_collection: Optional[str] = None,
95+
9496
**kwargs
9597
):
9698
sig = inspect.signature(super().run)
@@ -111,13 +113,13 @@ async def run(
111113
observer.result = result.output
112114
update_trace_context(
113115

114-
trace_name=trace_name if trace_name is not None else self.trace_name,
115-
trace_tags=trace_tags if trace_tags is not None else self.trace_tags,
116-
trace_metadata=trace_metadata if trace_metadata is not None else self.trace_metadata,
117-
trace_thread_id=trace_thread_id if trace_thread_id is not None else self.trace_thread_id,
118-
trace_user_id=trace_user_id if trace_user_id is not None else self.trace_user_id,
119-
trace_metric_collection=trace_metric_collection if trace_metric_collection is not None else self.trace_metric_collection,
120-
trace_metrics=trace_metrics if trace_metrics is not None else self.trace_metrics,
116+
trace_name=name if name is not None else self.trace_name,
117+
trace_tags=tags if tags is not None else self.trace_tags,
118+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
119+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
120+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
121+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
122+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
121123

122124
trace_input=input,
123125
trace_output=result.output,
@@ -134,15 +136,17 @@ async def run(
134136
return result
135137

136138
def run_sync(
137-
self,
139+
self,
138140
*args,
139-
trace_name: Optional[str] = None,
140-
trace_tags: Optional[List[str]] = None,
141-
trace_metadata: Optional[dict] = None,
142-
trace_thread_id: Optional[str] = None,
143-
trace_user_id: Optional[str] = None,
144-
trace_metric_collection: Optional[str] = None,
145-
trace_metrics: Optional[List[BaseMetric]] = None,
141+
142+
name: Optional[str] = None,
143+
tags: Optional[List[str]] = None,
144+
metadata: Optional[dict] = None,
145+
thread_id: Optional[str] = None,
146+
user_id: Optional[str] = None,
147+
metric_collection: Optional[str] = None,
148+
metrics: Optional[List[BaseMetric]] = None,
149+
146150
**kwargs
147151
):
148152
sig = inspect.signature(super().run_sync)
@@ -169,13 +173,13 @@ def run_sync(
169173
observer.result = result.output
170174
update_trace_context(
171175

172-
trace_name=trace_name if trace_name is not None else self.trace_name,
173-
trace_tags=trace_tags if trace_tags is not None else self.trace_tags,
174-
trace_metadata=trace_metadata if trace_metadata is not None else self.trace_metadata,
175-
trace_thread_id=trace_thread_id if trace_thread_id is not None else self.trace_thread_id,
176-
trace_user_id=trace_user_id if trace_user_id is not None else self.trace_user_id,
177-
trace_metric_collection=trace_metric_collection if trace_metric_collection is not None else self.trace_metric_collection,
178-
trace_metrics=trace_metrics if trace_metrics is not None else self.trace_metrics,
176+
trace_name=name if name is not None else self.trace_name,
177+
trace_tags=tags if tags is not None else self.trace_tags,
178+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
179+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
180+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
181+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
182+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
179183

180184
trace_input=input,
181185
trace_output=result.output,
@@ -196,13 +200,15 @@ def run_sync(
196200
async def run_stream(
197201
self,
198202
*args,
199-
trace_name: Optional[str] = None,
200-
trace_tags: Optional[List[str]] = None,
201-
trace_metadata: Optional[dict] = None,
202-
trace_thread_id: Optional[str] = None,
203-
trace_user_id: Optional[str] = None,
204-
trace_metric_collection: Optional[str] = None,
205-
trace_metrics: Optional[List[BaseMetric]] = None,
203+
204+
name: Optional[str] = None,
205+
tags: Optional[List[str]] = None,
206+
metadata: Optional[dict] = None,
207+
thread_id: Optional[str] = None,
208+
user_id: Optional[str] = None,
209+
metric_collection: Optional[str] = None,
210+
metrics: Optional[List[BaseMetric]] = None,
211+
206212
**kwargs
207213
):
208214
sig = inspect.signature(super().run_stream)
@@ -233,13 +239,15 @@ async def run_stream(
233239
pass
234240

235241
update_trace_context(
236-
trace_name=trace_name if trace_name is not None else self.trace_name,
237-
trace_tags=trace_tags if trace_tags is not None else self.trace_tags,
238-
trace_metadata=trace_metadata if trace_metadata is not None else self.trace_metadata,
239-
trace_thread_id=trace_thread_id if trace_thread_id is not None else self.trace_thread_id,
240-
trace_user_id=trace_user_id if trace_user_id is not None else self.trace_user_id,
241-
trace_metric_collection=trace_metric_collection if trace_metric_collection is not None else self.trace_metric_collection,
242-
trace_metrics=trace_metrics if trace_metrics is not None else self.trace_metrics,
242+
243+
trace_name=name if name is not None else self.trace_name,
244+
trace_tags=tags if tags is not None else self.trace_tags,
245+
trace_metadata=metadata if metadata is not None else self.trace_metadata,
246+
trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id,
247+
trace_user_id=user_id if user_id is not None else self.trace_user_id,
248+
trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection,
249+
trace_metrics=metrics if metrics is not None else self.trace_metrics,
250+
243251
trace_input=input,
244252
trace_output=(final_result if final_result is not None else None),
245253
)

tests/test_integrations/test_pydanticai/pydantic_all_tests.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,19 @@ def get_weather(city: str) -> str:
2727
)
2828

2929
async def execute_agent_stream():
30-
async with agent.run_stream("What is the weather in London?", trace_name="test_name_2") as result:
30+
async with agent.run_stream("What is the weather in London?", name="test_name_2") as result:
3131
async for chunk in result.stream_text(delta=True):
3232
print(chunk, end="", flush=True)
3333
final = await result.get_output()
3434
print("\n\nFinal:", final)
3535

3636
async def execute_agent_run():
37-
result = await agent.run("What is the weather in London?", trace_name="test_name_4")
37+
result = await agent.run("What is the weather in London?", name="test_name_4")
3838
print(result.output)
3939

4040
def execute_all():
4141
asyncio.run(execute_agent_stream())
42-
agent.run_sync("What is the weather in London?", trace_name="test_name_3")
43-
asyncio.run(execute_agent_run())
42+
agent.run_sync("What is the weather in London?", name="test_name_3")
43+
asyncio.run(execute_agent_run())
44+
45+
execute_all()

tests/test_integrations/test_pydanticai/pydanticai_app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ async def run_agent(input_query: str):
8282
input_query,
8383
deps=deps,
8484
# metric_collection="test_collection_1",
85-
trace_name="test_trace_2",
85+
name="test_trace_2",
8686
# tags=["test_tag_1"],
8787
# metadata={"test_metadata_1": "test_metadata_1"},
8888
# thread_id="test_thread_id_1",
89-
trace_user_id="test_user_id_2",
89+
user_id="test_user_id_2",
9090
)
9191

9292
return result.output
@@ -95,3 +95,5 @@ async def run_agent(input_query: str):
9595
def execute_agent():
9696
output = asyncio.run(run_agent("What's the weather in Paris?"))
9797
return output
98+
99+
execute_agent()

0 commit comments

Comments
 (0)