-
Notifications
You must be signed in to change notification settings - Fork 424
Expand file tree
/
Copy pathcli.py
More file actions
135 lines (111 loc) · 4.21 KB
/
cli.py
File metadata and controls
135 lines (111 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import asyncio
import os
import signal
import uuid
from typing import Any
import grpc
import httpx
from a2a.client import A2ACardResolver, ClientConfig, create_client
from a2a.helpers import get_artifact_text, get_message_text
from a2a.helpers.agent_card import display_agent_card
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
# Task states after which the CLI stops tracking a task, so the next user
# message is sent without a task_id (i.e. starts a new task).
_TERMINAL_TASK_STATES = frozenset(
    {
        'TASK_STATE_COMPLETED',
        'TASK_STATE_FAILED',
        'TASK_STATE_CANCELED',
        'TASK_STATE_REJECTED',
    }
)


async def _handle_stream(
    stream: Any, current_task_id: str | None
) -> str | None:
    """Print every event of one response stream and track the active task.

    Args:
        stream: Async iterable of response events (as returned by
            ``client.send_message``). Each event is inspected with
            ``HasField`` for ``message``, ``task``, ``status_update`` or
            ``artifact_update``.
        current_task_id: ID of the task the outgoing message continued, or
            ``None`` if no task is currently being tracked.

    Returns:
        The task ID to attach to the next outgoing message, or ``None`` when
        the agent answered with a plain message or the task reached one of
        ``_TERMINAL_TASK_STATES``.

    Raises:
        ValueError: If no task is being tracked and an event other than a
            ``message`` or ``task`` arrives.
    """
    async for event in stream:
        if event.HasField('message'):
            # A direct message reply ends the exchange without any task.
            print('Message:', get_message_text(event.message, delimiter=' '))
            return None

        if event.HasField('task'):
            task = event.task
            if not current_task_id:
                print('--- Task Started ---')
            current_task_id = task.id
            task_state = TaskState.Name(task.status.state)
            print(f'Task [state={task_state}]')
            if task_state in _TERMINAL_TASK_STATES:
                # The task can already be finished when it is reported
                # (presumably a non-streaming reply); stop tracking it so the
                # next message does not target a terminal task.
                current_task_id = None
                print('--- Task Finished ---')
            continue

        if not current_task_id:
            # Status/artifact updates only make sense for a known task.
            raise ValueError(f'Unexpected first event: {event}')

        if event.HasField('status_update'):
            status = event.status_update.status
            state_name = TaskState.Name(status.state)
            message_text = ''
            if status.HasField('message'):
                message_text = ': ' + get_message_text(
                    status.message, delimiter=' '
                )
            print(f'TaskStatusUpdate [state={state_name}]{message_text}')
            if state_name in _TERMINAL_TASK_STATES:
                current_task_id = None
                print('--- Task Finished ---')
        elif event.HasField('artifact_update'):
            artifact = event.artifact_update.artifact
            print(
                f'TaskArtifactUpdate [name={artifact.name}]:',
                get_artifact_text(artifact, delimiter=' '),
            )
    return current_task_id
async def main() -> None:
    """Run the A2A terminal client.

    Parses ``--url`` and ``--transport`` from the command line, fetches the
    agent card from ``--url``, creates a client for it and then runs an
    interactive loop.

    Each loop iteration reads one line from stdin, sends it as a user message
    and prints the streamed response. The loop ends when the user types
    ``/quit`` or ``/exit``, on Ctrl-C, or when stdin is closed. The client is
    closed however the loop ends.
    """
    parser = argparse.ArgumentParser(description='A2A Terminal Client')
    parser.add_argument(
        '--url', default='http://127.0.0.1:41241', help='Agent base URL'
    )
    parser.add_argument(
        '--transport',
        default=None,
        help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)',
    )
    args = parser.parse_args()

    config = ClientConfig(
        grpc_channel_factory=grpc.aio.insecure_channel,
    )
    if args.transport:
        # Restrict the client to the single binding the user asked for.
        config.supported_protocol_bindings = [args.transport]

    print(
        f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})'
    )

    async with httpx.AsyncClient() as httpx_client:
        resolver = A2ACardResolver(httpx_client, args.url)
        card = await resolver.get_agent_card()
        print('\n✓ Agent Card Found:')
        display_agent_card(card)

        client = await create_client(card, client_config=config)
        try:
            # `_transport` is a private attribute; fall back to the client
            # itself if it is absent.
            actual_transport = getattr(client, '_transport', client)
            print(f' Picked Transport: {actual_transport.__class__.__name__}')
            print('\nConnected! Send a message or type /quit to exit.')

            loop = asyncio.get_running_loop()
            current_task_id: str | None = None
            # Every message in this session shares one context ID.
            current_context_id = str(uuid.uuid4())
            while True:
                try:
                    # input() blocks, so run it in the default executor to
                    # keep the event loop free.
                    user_input = await loop.run_in_executor(
                        None, input, 'You: '
                    )
                except (KeyboardInterrupt, EOFError):
                    # EOFError: stdin was closed (Ctrl-D or end of piped input).
                    break

                command = user_input.strip()
                if command.lower() in ('/quit', '/exit'):
                    break
                if not command:
                    continue

                message = Message(
                    role=Role.ROLE_USER,
                    message_id=str(uuid.uuid4()),
                    parts=[Part(text=user_input)],
                    task_id=current_task_id,
                    context_id=current_context_id,
                )
                request = SendMessageRequest(message=message)
                try:
                    stream = client.send_message(request)
                    current_task_id = await _handle_stream(
                        stream, current_task_id
                    )
                except (httpx.RequestError, grpc.RpcError, ValueError) as e:
                    # ValueError: _handle_stream saw an unexpected first event;
                    # report it instead of aborting the whole session.
                    print(f'Error communicating with agent: {e}')
        finally:
            await client.close()
if __name__ == '__main__':

    def _exit_immediately(signum, frame):
        """SIGINT handler: terminate the process at once with status 0.

        os._exit is used instead of sys.exit so the process does not wait on
        the executor thread that may be blocked inside input().
        """
        os._exit(0)

    signal.signal(signal.SIGINT, _exit_immediately)
    asyncio.run(main())