-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathagent.py
More file actions
137 lines (119 loc) · 4.56 KB
/
agent.py
File metadata and controls
137 lines (119 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import asyncio
from datetime import datetime
from typing import Any, Dict, Optional
from copy import deepcopy
from google.adk import Agent
from google.adk.apps import App
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.tool_context import ToolContext
from google.genai import types
from toolbox_adk import CredentialStrategy, ToolboxToolset, ToolboxTool
# Instruction text for the hotel assistant; main() hands it to the Agent as
# its `instruction`. The wording is part of runtime behavior, so edit with care.
SYSTEM_PROMPT = """
You're a helpful hotel assistant. You handle hotel searching, booking and
cancellations. When the user searches for a hotel, mention it's name, id,
location and price tier. Always mention hotel ids while performing any
searches. This is very important for any operations. For any bookings or
cancellations, please provide the appropriate confirmation. Be sure to
update checkin or checkout dates if mentioned by the user.
Don't ask for confirmations from the user.
"""
# Pre processing
# Longest stay, in whole days, that the policy check lets through.
MAX_STAY_DAYS = 14


async def enfore_business_rules(
    tool: ToolboxTool, args: Dict[str, Any], tool_context: ToolContext
) -> Optional[Dict[str, Any]]:
    """Check a tool call against the stay-length policy before it runs.

    Only ``update-hotel`` calls that carry both ``checkin_date`` and
    ``checkout_date`` are inspected; every other call passes the check.

    Args:
        tool: The tool about to be executed; only its ``name`` is read.
        args: Arguments the tool is about to be called with.
        tool_context: Unused here; present to match the callback signature.

    Returns:
        ``None`` when the call passes the check, otherwise a
        ``{"result": "Error: ..."}`` dict describing why it was rejected
        (unparseable dates, or a stay longer than ``MAX_STAY_DAYS`` days).
    """
    tool_name = tool.name
    print(f"POLICY CHECK: Intercepting '{tool_name}'")

    if tool_name != "update-hotel":
        return None
    if "checkin_date" not in args or "checkout_date" not in args:
        return None

    try:
        start = datetime.fromisoformat(args["checkin_date"])
        end = datetime.fromisoformat(args["checkout_date"])
        # Subtraction is inside the try: mixing naive and timezone-aware
        # values raises TypeError here rather than in fromisoformat.
        duration = (end - start).days
    except (TypeError, ValueError):
        # Fail closed: if the dates cannot be read, the policy cannot be
        # verified, so reject instead of crashing the callback.
        print("BLOCKED: Invalid dates")
        return {"result": "Error: Dates must be ISO formatted (YYYY-MM-DD)."}

    if duration > MAX_STAY_DAYS:
        print("BLOCKED: Stay too long")
        return {"result": f"Error: Maximum stay duration is {MAX_STAY_DAYS} days."}
    return None
# Post processing
async def enrich_response(
    tool: ToolboxTool,
    args: Dict[str, Any],
    tool_context: ToolContext,
    tool_response: Any,
) -> Optional[Any]:
    """Add a loyalty-points confirmation to successful ``book-hotel`` results.

    Args:
        tool: The tool that was executed; only its ``name`` is read.
        args: Arguments the tool was called with (unused here).
        tool_context: Unused here; present to match the callback signature.
        tool_response: The tool's output. A ``str`` is treated as the result
            text; a ``dict`` is expected to carry it under ``"result"``.

    Returns:
        ``None`` when nothing is changed. Otherwise the enriched text, as a
        ``str`` for string responses, or as a deep copy of the dict with
        ``"result"`` replaced for dict responses. The input is not mutated.
    """
    if isinstance(tool_response, dict):
        # No default here: a dict without "result" (for example one that only
        # carries an error field) must not be mistaken for an empty success.
        result = tool_response.get("result")
    elif isinstance(tool_response, str):
        result = tool_response
    else:
        return None

    # Results containing "Error" follow this file's failure convention.
    if not isinstance(result, str) or "Error" in result:
        return None
    if tool.name != "book-hotel":
        return None

    loyalty_bonus = 500
    enriched_result = f"Booking Confirmed!\n You earned {loyalty_bonus} Loyalty Points with this stay.\n\nSystem Details: {result}"

    if isinstance(tool_response, dict):
        modified_response = deepcopy(tool_response)
        modified_response["result"] = enriched_result
        return modified_response
    return enriched_result
async def run_chat_turn(
    runner: Runner, session_id: str, user_id: str, message_text: str
):
    """Send one user message through the runner and print both sides.

    The text parts of every event yielded by ``runner.run_async`` are
    concatenated in arrival order and printed once the stream ends.
    """
    print(f"\nUSER: '{message_text}'")

    user_message = types.Content(
        role="user", parts=[types.Part(text=message_text)]
    )
    reply = ""
    async for event in runner.run_async(
        user_id=user_id,
        session_id=session_id,
        new_message=user_message,
    ):
        content = event.content
        # Skip events that carry no content or no parts.
        if not (content and content.parts):
            continue
        reply += "".join(part.text for part in content.parts if part.text)

    print(f"AI: {reply}")
async def main():
    """Build the hotel agent and run two scripted chat turns against it.

    Loads tools from the Toolbox server at ``http://127.0.0.1:5000``, wires
    the pre- and post-processing callbacks into the agent, creates an
    in-memory session, and sends two messages. The toolset is closed even
    when any of those steps raises.
    """
    toolset = ToolboxToolset(
        server_url="http://127.0.0.1:5000",
        toolset_name="my-toolset",
        credentials=CredentialStrategy.toolbox_identity(),
    )
    try:
        tools = await toolset.get_tools()
        root_agent = Agent(
            name="root_agent",
            model="gemini-2.5-flash",
            instruction=SYSTEM_PROMPT,
            tools=tools,
            # add any pre and post processing callbacks
            before_tool_callback=enfore_business_rules,
            after_tool_callback=enrich_response,
        )
        app = App(root_agent=root_agent, name="my_agent")
        runner = Runner(app=app, session_service=InMemorySessionService())

        session_id = "test-session"
        user_id = "test-user"
        await runner.session_service.create_session(
            app_name=app.name, user_id=user_id, session_id=session_id
        )

        # First turn: Successful booking
        await run_chat_turn(runner, session_id, user_id, "Book hotel with id 3.")
        print("-" * 50)

        # Second turn: Policy violation (stay > 14 days)
        await run_chat_turn(
            runner,
            session_id,
            user_id,
            "Book a hotel with id 5 with checkin date 2025-01-18 and checkout date 2025-02-10",
        )
    finally:
        # Release the toolset on every exit path, not only after a clean run.
        await toolset.close()
# Script entry point: run the async demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    asyncio.run(main())