Skip to content

Commit f9d477a

Browse files
authored
chore: set up basic a2a protocol binding into flare-ai-kit (#76)
2 parents c665e92 + 4b740a0 commit f9d477a

File tree

14 files changed

+3362
-1122
lines changed

14 files changed

+3362
-1122
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ wheels/
1313
.ruff_cache
1414
.pytest_cache
1515
.vscode
16-
.DS_Store
16+
.DS_Store
17+
18+
19+
# sqlite
20+
*.db
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import locale
2+
import os
3+
from uuid import uuid4
4+
5+
from dotenv import load_dotenv
6+
from pydantic import BaseModel
7+
from pydantic_ai import Agent, RunContext
8+
from pydantic_ai.models.gemini import GeminiModel
9+
from pydantic_ai.providers.google_gla import GoogleGLAProvider
10+
11+
from flare_ai_kit import FlareAIKit
12+
from flare_ai_kit.a2a import A2AServer
13+
from flare_ai_kit.a2a.schemas import (
14+
AgentCapabilities,
15+
AgentCard,
16+
AgentProvider,
17+
AgentSkill,
18+
Message,
19+
SendMessageRequest,
20+
SendMessageResponse,
21+
TextPart,
22+
)
23+
from flare_ai_kit.a2a.task_management import TaskManager
24+
25+
load_dotenv()
26+
27+
28+
class PriceRequest(BaseModel):
29+
"""Model for price requests."""
30+
31+
symbol: str
32+
base_currency: str | None = "USD"
33+
34+
35+
class PriceResponse(BaseModel):
36+
"""Model for price responses."""
37+
38+
symbol: str
39+
price: float
40+
formated_price: str
41+
currency: str
42+
timestamp: str
43+
44+
45+
class AgentDependencies(BaseModel):
46+
"""Dependencies for the agent."""
47+
48+
flare_kit: FlareAIKit
49+
50+
model_config = {
51+
"arbitrary_types_allowed": True,
52+
}
53+
54+
55+
model = GeminiModel(
56+
os.getenv("AGENT__GEMINI_MODEL", "gemini-2.5-flash"),
57+
provider=GoogleGLAProvider(api_key=os.getenv("AGENT__GEMINI_API_KEY")),
58+
)
59+
60+
price_agent = Agent(
61+
model,
62+
deps_type=AgentDependencies,
63+
retries=2,
64+
system_prompt=(
65+
"You are a crypto price assistant. "
66+
"You help users get current prices for cryptocurrency pairs. "
67+
"When users ask for prices, "
68+
"identify the cryptocurrency symbols they want "
69+
"and call the appropriate tools. "
70+
"Common symbols include FLR, BTC, ETH, ADA, etc. "
71+
"Default to USD pairs unless specified otherwise. "
72+
"Do also display the formated price in "
73+
"parenthesis for easier readability. "
74+
"An example: The current price of BTC "
75+
"is 111070.88 USD ($111,070.88). "
76+
"Ensure you add the forward slash "
77+
"for the pair between the symbols."
78+
),
79+
)
80+
81+
82+
def format_price(price: float, target_locale: str = "en_US.UTF-8") -> str:
83+
"""Returns price formated in currency: defaults to USD."""
84+
locale.setlocale(locale.LC_ALL, target_locale)
85+
return locale.currency(price, grouping=True)
86+
87+
88+
@price_agent.tool
89+
async def get_crypto_price(
90+
ctx: RunContext[AgentDependencies], symbol: str
91+
) -> PriceResponse:
92+
"""Get the latest price for a cryptocurrency pair."""
93+
try:
94+
ftso = await ctx.deps.flare_kit.ftso
95+
96+
price = await ftso.get_latest_price(symbol)
97+
98+
return PriceResponse(
99+
symbol=symbol,
100+
price=price,
101+
formated_price=format_price(price),
102+
currency="USD",
103+
timestamp="now",
104+
)
105+
except Exception as e:
106+
msg = f"Could not get price for {symbol}: {e!s}"
107+
raise ValueError(msg) from e
108+
109+
110+
@price_agent.tool
111+
async def get_multiple_prices(
112+
ctx: RunContext[AgentDependencies], symbols: list[str]
113+
) -> list[PriceResponse]:
114+
"""Get prices for multiple cryptocurrency pairs."""
115+
try:
116+
# Use FlareAIKit to get the FTSO client
117+
ftso = await ctx.deps.flare_kit.ftso
118+
prices_data = await ftso.get_latest_prices(symbols)
119+
120+
results: list[PriceResponse] = []
121+
for i, symbol in enumerate(symbols):
122+
price = prices_data[i]
123+
results.append(
124+
PriceResponse(
125+
symbol=symbol,
126+
price=price,
127+
formated_price=format_price(price),
128+
currency="USD",
129+
timestamp="now",
130+
)
131+
)
132+
except Exception as e:
133+
msg = f"Could not get prices for {symbols}: {e!s}"
134+
raise ValueError(msg) from e
135+
else:
136+
return results
137+
138+
139+
task_manager = TaskManager()
140+
141+
142+
async def handle_send_message(request_body: SendMessageRequest) -> SendMessageResponse:
143+
"""Message send handler."""
144+
try:
145+
user_message = ""
146+
for part in request_body.params.message.parts:
147+
if part.kind == "text":
148+
user_message += part.text
149+
150+
kit = FlareAIKit()
151+
deps = AgentDependencies(flare_kit=kit)
152+
153+
result = await price_agent.run(user_message, deps=deps)
154+
155+
response_text = str(result.output)
156+
157+
return SendMessageResponse(
158+
result=Message(
159+
messageId=uuid4().hex,
160+
role="agent",
161+
parts=[TextPart(text=response_text)],
162+
)
163+
)
164+
except Exception as e:
165+
error_message = f"I apologize, but I encountered an error: {e!s}"
166+
return SendMessageResponse(
167+
result=Message(
168+
messageId=uuid4().hex,
169+
role="agent",
170+
parts=[TextPart(text=error_message)],
171+
)
172+
)
173+
174+
175+
if __name__ == "__main__":
176+
protocol = "http"
177+
host = "localhost"
178+
port = 4500
179+
180+
base_url = f"{protocol}://{host}:{port}"
181+
182+
card = AgentCard(
183+
name="FTSO Agent",
184+
version="0.1.0",
185+
url=base_url,
186+
description="An agent that gets prices",
187+
provider=AgentProvider(
188+
organization="Flare Foundation", url="https://flare.network"
189+
),
190+
capabilities=AgentCapabilities(streaming=False, pushNotifications=False),
191+
skills=[
192+
AgentSkill(
193+
id="list_prices",
194+
name="List Prices",
195+
tags=["BTC", "FLR"],
196+
description="List prices of crypto pairs",
197+
examples=["List prices for FLR/USD and BTC/USD"],
198+
inputModes=["text/plain"],
199+
outputModes=["text/plain"],
200+
)
201+
],
202+
)
203+
204+
a2a_server = A2AServer(card, host=host, port=port)
205+
a2a_server.service.add_handler(SendMessageRequest, handle_send_message)
206+
a2a_server.run()

0 commit comments

Comments
 (0)