1+ # from pydantic_ai import RunContext
2+ # import asyncio
3+
4+ # from deepeval.integrations.pydantic_ai import Agent
5+
6+ # joke_selection_agent = Agent(
7+ # 'openai:gpt-4o',
8+ # system_prompt=(
9+ # 'Use the `joke_factory` to generate some jokes, then choose the best. '
10+ # 'You must return just a single joke.'
11+ # ),
12+ # trace_name="joke_selection_agent",
13+ # )
14+ # joke_generation_agent = Agent(
15+ # 'openai:gpt-4o', output_type=list[str],
16+ # )
17+
18+
19+ # @joke_selection_agent.tool
20+ # async def joke_factory(ctx: RunContext[None], count: int) -> list[str]:
21+ # r = await joke_generation_agent.run(
22+ # f'Please generate {count} jokes.',
23+ # usage=ctx.usage,
24+ # )
25+ # return r.output
26+
27+ # async def execute_agent():
28+ # result = await joke_selection_agent.run('Tell me a joke.', name="joke_selection_agent_2")
29+ # print(result.output)
30+
31+ # asyncio.run(execute_agent())
32+ # result = joke_selection_agent.run_sync('Tell me a joke.')
33+ # print(result.output)
34+ #> Did you hear about the toothpaste scandal? They called it Colgate.
35+
36+
37+ ########################################################
38+
39+ # from dataclasses import dataclass
40+ # import asyncio
41+ # import httpx
42+
43+ # from pydantic_ai import RunContext
44+ # from deepeval.integrations.pydantic_ai import Agent
45+
46+
47+ # @dataclass
48+ # class ClientAndKey:
49+ # http_client: httpx.AsyncClient
50+ # api_key: str
51+
52+
53+ # joke_selection_agent = Agent(
54+ # 'openai:gpt-4o',
55+ # deps_type=ClientAndKey,
56+ # system_prompt=(
57+ # 'Use the `joke_factory` tool to generate some jokes on the given subject, '
58+ # 'then choose the best. You must return just a single joke.'
59+ # ),
60+ # )
61+ # joke_generation_agent = Agent(
62+ # 'openai:gpt-4o',
63+ # deps_type=ClientAndKey,
64+ # output_type=list[str],
65+ # system_prompt=(
66+ # 'Use the "get_jokes" tool to get some jokes on the given subject, '
67+ # 'then extract each joke into a list.'
68+ # ),
69+ # )
70+
71+
72+ # @joke_selection_agent.tool
73+ # async def joke_factory(ctx: RunContext[ClientAndKey], count: int) -> list[str]:
74+ # r = await joke_generation_agent.run(
75+ # f'Please generate {count} jokes.',
76+ # deps=ctx.deps,
77+ # usage=ctx.usage,
78+ # )
79+ # return r.output
80+
81+
82+ # @joke_generation_agent.tool
83+ # async def get_jokes(ctx: RunContext[ClientAndKey], count: int) -> str:
84+ # response = await ctx.deps.http_client.get(
85+ # 'https://example.com',
86+ # params={'count': count},
87+ # headers={'Authorization': f'Bearer {ctx.deps.api_key}'},
88+ # )
89+ # response.raise_for_status()
90+ # return response.text
91+
92+
93+ # async def main():
94+ # async with httpx.AsyncClient() as client:
95+ # deps = ClientAndKey(client, 'foobar')
96+ # result = await joke_selection_agent.run('Tell me a joke.', deps=deps)
97+ # print(result.output)
98+ # #> Did you hear about the toothpaste scandal? They called it Colgate.
99+ # # print(result.usage())
100+ # #> RunUsage(input_tokens=309, output_tokens=32, requests=4, tool_calls=2)
101+
102+ # asyncio.run(main())
103+
104+
105+
106+ from typing import Literal
107+
108+ from pydantic import BaseModel , Field
109+ from rich .prompt import Prompt
110+
111+ from pydantic_ai import RunContext
112+ from deepeval .integrations .pydantic_ai import Agent , instrument_pydantic_ai
113+ from pydantic_ai .messages import ModelMessage
114+
115+ instrument_pydantic_ai ()
116+
117+
118+ class FlightDetails (BaseModel ):
119+ flight_number : str
120+
121+
122+ class Failed (BaseModel ):
123+ """Unable to find a satisfactory choice."""
124+
125+
126+ flight_search_agent = Agent [None , FlightDetails | Failed ](
127+ 'openai:gpt-4o' ,
128+ name = "flight_search_agent" ,
129+ output_type = FlightDetails | Failed , # type: ignore
130+ system_prompt = (
131+ 'Use the "flight_search" tool to find a flight '
132+ 'from the given origin to the given destination.'
133+ ),
134+ )
135+
136+
137+ @flight_search_agent .tool
138+ async def flight_search (
139+ ctx : RunContext [None ], origin : str , destination : str
140+ ) -> FlightDetails | None :
141+ # in reality, this would call a flight search API or
142+ # use a browser to scrape a flight search website
143+ return FlightDetails (flight_number = 'AK456' )
144+
145+
146+
147+ async def find_flight () -> FlightDetails | None :
148+ message_history : list [ModelMessage ] | None = None
149+ for _ in range (3 ):
150+ prompt = Prompt .ask (
151+ 'Where would you like to fly from and to?' ,
152+ )
153+ result = await flight_search_agent .run (
154+ prompt ,
155+ message_history = message_history ,
156+ )
157+ if isinstance (result .output , FlightDetails ):
158+ return result .output
159+ else :
160+ message_history = result .all_messages (
161+ output_tool_return_content = 'Please try again.'
162+ )
163+
164+
165+ class SeatPreference (BaseModel ):
166+ row : int = Field (ge = 1 , le = 30 )
167+ seat : Literal ['A' , 'B' , 'C' , 'D' , 'E' , 'F' ]
168+
169+
170+ # This agent is responsible for extracting the user's seat selection
171+ seat_preference_agent = Agent [None , SeatPreference | Failed ](
172+ 'openai:gpt-4o' ,
173+ name = "seat_preference_agent" ,
174+ output_type = SeatPreference | Failed , # type: ignore
175+ system_prompt = (
176+ "Extract the user's seat preference. "
177+ 'Seats A and F are window seats. '
178+ 'Row 1 is the front row and has extra leg room. '
179+ 'Rows 14, and 20 also have extra leg room. '
180+ ),
181+ )
182+
183+
184+ async def find_seat () -> SeatPreference :
185+ message_history : list [ModelMessage ] | None = None
186+ while True :
187+ answer = Prompt .ask ('What seat would you like?' )
188+
189+ result = await seat_preference_agent .run (
190+ answer ,
191+ message_history = message_history ,
192+ )
193+ if isinstance (result .output , SeatPreference ):
194+ return result .output
195+ else :
196+ print ('Could not understand seat preference. Please try again.' )
197+ message_history = result .all_messages ()
198+
199+
200+ async def main ():
201+
202+ opt_flight_details = await find_flight ()
203+ if opt_flight_details is not None :
204+ print (f'Flight found: { opt_flight_details .flight_number } ' )
205+ #> Flight found: AK456
206+ seat_preference = await find_seat ()
207+ print (f'Seat preference: { seat_preference } ' )
208+ #> Seat preference: row=1 seat='A'
209+
210+ # import asyncio
211+ # asyncio.run(main())
0 commit comments