Skip to content

Commit 2b1a13e

Browse files
.
1 parent f00cd91 commit 2b1a13e

4 files changed

Lines changed: 218 additions & 6 deletions

File tree

deepeval/integrations/pydantic_ai/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Optional, List
2+
from typing import Optional, List, Generic, TypeVar
33
from contextvars import ContextVar
44
from contextlib import asynccontextmanager
55

@@ -12,6 +12,8 @@
1212

1313
try:
1414
from pydantic_ai.agent import Agent
15+
from pydantic_ai.tools import AgentDepsT
16+
from pydantic_ai.output import OutputDataT
1517
from deepeval.integrations.pydantic_ai.utils import create_patched_tool, update_trace_context, patch_llm_model
1618
is_pydantic_ai_installed = True
1719
except:
@@ -25,7 +27,7 @@ def pydantic_ai_installed():
2527

2628
_IS_RUN_SYNC = ContextVar("deepeval_is_run_sync", default=False)
2729

28-
class DeepEvalPydanticAIAgent(Agent):
30+
class DeepEvalPydanticAIAgent(Agent[AgentDepsT, OutputDataT], Generic[AgentDepsT, OutputDataT]):
2931

3032
trace_name: Optional[str] = None
3133
trace_tags: Optional[List[str]] = None

deepeval/integrations/pydantic_ai/patcher.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def instrument(otel: Optional[bool] = False, api_key: Optional[str] = None):
5353
Please deepeval.integrations.pydantic_ai.Agent to instrument instead.
5454
"""
5555
warnings.warn(
56-
"The 'instrument_pydantic_ai()' function is deprecated and will be removed in a future version.",
57-
"Please use deepeval.integrations.pydantic_ai.Agent to instrument instead. Refer to the documenation [link]", #TODO: add the link,
58-
DeprecationWarning,
56+
"The 'instrument_pydantic_ai()' function is deprecated and will be removed in a future version. "
57+
"Please use deepeval.integrations.pydantic_ai.Agent to instrument instead. Refer to the documentation [link]", #TODO: add the link,
58+
UserWarning,
5959
stacklevel=2
6060
)
6161

deepeval/integrations/pydantic_ai/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ async def stream_wrapper(*args, **kwargs):
152152
async with stream_original_func(*args, **kwargs) as streamed_response:
153153
try:
154154
yield streamed_response
155-
print("streamed_response >>>>>")
156155
if not llm_span.token_intervals:
157156
llm_span.token_intervals = {perf_counter(): "NA"}
158157
else:
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# from pydantic_ai import RunContext
2+
# import asyncio
3+
4+
# from deepeval.integrations.pydantic_ai import Agent
5+
6+
# joke_selection_agent = Agent(
7+
# 'openai:gpt-4o',
8+
# system_prompt=(
9+
# 'Use the `joke_factory` to generate some jokes, then choose the best. '
10+
# 'You must return just a single joke.'
11+
# ),
12+
# trace_name="joke_selection_agent",
13+
# )
14+
# joke_generation_agent = Agent(
15+
# 'openai:gpt-4o', output_type=list[str],
16+
# )
17+
18+
19+
# @joke_selection_agent.tool
20+
# async def joke_factory(ctx: RunContext[None], count: int) -> list[str]:
21+
# r = await joke_generation_agent.run(
22+
# f'Please generate {count} jokes.',
23+
# usage=ctx.usage,
24+
# )
25+
# return r.output
26+
27+
# async def execute_agent():
28+
# result = await joke_selection_agent.run('Tell me a joke.', name="joke_selection_agent_2")
29+
# print(result.output)
30+
31+
# asyncio.run(execute_agent())
32+
# result = joke_selection_agent.run_sync('Tell me a joke.')
33+
# print(result.output)
34+
#> Did you hear about the toothpaste scandal? They called it Colgate.
35+
36+
37+
########################################################
38+
39+
# from dataclasses import dataclass
40+
# import asyncio
41+
# import httpx
42+
43+
# from pydantic_ai import RunContext
44+
# from deepeval.integrations.pydantic_ai import Agent
45+
46+
47+
# @dataclass
48+
# class ClientAndKey:
49+
# http_client: httpx.AsyncClient
50+
# api_key: str
51+
52+
53+
# joke_selection_agent = Agent(
54+
# 'openai:gpt-4o',
55+
# deps_type=ClientAndKey,
56+
# system_prompt=(
57+
# 'Use the `joke_factory` tool to generate some jokes on the given subject, '
58+
# 'then choose the best. You must return just a single joke.'
59+
# ),
60+
# )
61+
# joke_generation_agent = Agent(
62+
# 'openai:gpt-4o',
63+
# deps_type=ClientAndKey,
64+
# output_type=list[str],
65+
# system_prompt=(
66+
# 'Use the "get_jokes" tool to get some jokes on the given subject, '
67+
# 'then extract each joke into a list.'
68+
# ),
69+
# )
70+
71+
72+
# @joke_selection_agent.tool
73+
# async def joke_factory(ctx: RunContext[ClientAndKey], count: int) -> list[str]:
74+
# r = await joke_generation_agent.run(
75+
# f'Please generate {count} jokes.',
76+
# deps=ctx.deps,
77+
# usage=ctx.usage,
78+
# )
79+
# return r.output
80+
81+
82+
# @joke_generation_agent.tool
83+
# async def get_jokes(ctx: RunContext[ClientAndKey], count: int) -> str:
84+
# response = await ctx.deps.http_client.get(
85+
# 'https://example.com',
86+
# params={'count': count},
87+
# headers={'Authorization': f'Bearer {ctx.deps.api_key}'},
88+
# )
89+
# response.raise_for_status()
90+
# return response.text
91+
92+
93+
# async def main():
94+
# async with httpx.AsyncClient() as client:
95+
# deps = ClientAndKey(client, 'foobar')
96+
# result = await joke_selection_agent.run('Tell me a joke.', deps=deps)
97+
# print(result.output)
98+
# #> Did you hear about the toothpaste scandal? They called it Colgate.
99+
# # print(result.usage())
100+
# #> RunUsage(input_tokens=309, output_tokens=32, requests=4, tool_calls=2)
101+
102+
# asyncio.run(main())
103+
104+
105+
106+
from typing import Literal

from pydantic import BaseModel, Field
from rich.prompt import Prompt

from pydantic_ai import RunContext
from deepeval.integrations.pydantic_ai import Agent, instrument_pydantic_ai
from pydantic_ai.messages import ModelMessage

# Enable deepeval instrumentation for the pydantic-ai agents defined below.
# NOTE(review): instrument_pydantic_ai() is deprecated elsewhere in this commit
# in favor of the deepeval Agent wrapper — confirm this call is still intended.
instrument_pydantic_ai()
116+
117+
118+
class FlightDetails(BaseModel):
    """Structured agent output: details of a flight found by the search agent."""

    # Only the flight number is modeled in this example.
    flight_number: str
120+
121+
122+
# Sentinel output type: an agent returns this when it cannot produce a
# satisfactory structured result, letting the caller retry.
class Failed(BaseModel):
    """Unable to find a satisfactory choice."""
124+
125+
126+
# Flight-finding agent. The explicit generic parameters (deps=None,
# output=FlightDetails | Failed) give `result.output` a static type.
flight_search_agent = Agent[None, FlightDetails | Failed](
    'openai:gpt-4o',
    name="flight_search_agent",
    output_type=FlightDetails | Failed,  # type: ignore
    system_prompt=(
        'Use the "flight_search" tool to find a flight '
        'from the given origin to the given destination.'
    ),
)
135+
136+
137+
@flight_search_agent.tool
async def flight_search(
    ctx: RunContext[None], origin: str, destination: str
) -> FlightDetails | None:
    """Tool: look up a flight from `origin` to `destination`.

    Stub implementation that always returns a fixed flight.
    In reality, this would call a flight search API or
    use a browser to scrape a flight search website.
    """
    return FlightDetails(flight_number='AK456')
144+
145+
146+
147+
async def find_flight() -> FlightDetails | None:
    """Prompt the user and run the flight-search agent, up to three attempts.

    Returns the first successful FlightDetails, or None when every attempt
    fails. Failed attempts feed the accumulated message history (with a
    retry hint) back into the next run.
    """
    message_history: list[ModelMessage] | None = None
    attempts_remaining = 3
    while attempts_remaining > 0:
        attempts_remaining -= 1
        prompt = Prompt.ask('Where would you like to fly from and to?')
        result = await flight_search_agent.run(
            prompt, message_history=message_history
        )
        output = result.output
        if isinstance(output, FlightDetails):
            return output
        # Agent returned Failed: carry the conversation forward with a nudge.
        message_history = result.all_messages(
            output_tool_return_content='Please try again.'
        )
    return None
163+
164+
165+
class SeatPreference(BaseModel):
    """Structured agent output: the user's chosen seat."""

    # Row 1..30 inclusive, validated by pydantic.
    row: int = Field(ge=1, le=30)
    # Standard 3-3 cabin letters; per the system prompt, A and F are windows.
    seat: Literal['A', 'B', 'C', 'D', 'E', 'F']
168+
169+
170+
# This agent is responsible for extracting the user's seat selection.
# Returns Failed when the free-text answer cannot be parsed into a seat.
seat_preference_agent = Agent[None, SeatPreference | Failed](
    'openai:gpt-4o',
    name="seat_preference_agent",
    output_type=SeatPreference | Failed,  # type: ignore
    system_prompt=(
        "Extract the user's seat preference. "
        'Seats A and F are window seats. '
        'Row 1 is the front row and has extra leg room. '
        'Rows 14, and 20 also have extra leg room. '
    ),
)
182+
183+
184+
async def find_seat() -> SeatPreference:
    """Keep asking the user for a seat until the agent extracts a valid one.

    Loops indefinitely; each failed parse is reported to the user and the
    full message history is replayed into the next agent run.
    """
    message_history: list[ModelMessage] | None = None
    while True:
        answer = Prompt.ask('What seat would you like?')

        result = await seat_preference_agent.run(
            answer, message_history=message_history
        )
        output = result.output
        if isinstance(output, SeatPreference):
            return output
        # Parse failed: tell the user and retry with the prior conversation.
        print('Could not understand seat preference. Please try again.')
        message_history = result.all_messages()
198+
199+
200+
async def main():
    """Demo driver: find a flight, then (only if found) ask for a seat."""
    flight = await find_flight()
    # Guard clause: no flight found within the allowed attempts — nothing to do.
    if flight is None:
        return
    print(f'Flight found: {flight.flight_number}')
    #> Flight found: AK456
    seat = await find_seat()
    print(f'Seat preference: {seat}')
    #> Seat preference: row=1 seat='A'
209+
210+
# import asyncio
211+
# asyncio.run(main())

0 commit comments

Comments
 (0)