1- from collections .abc import AsyncIterator
1+ from collections .abc import AsyncIterator , AsyncGenerator
22from typing import Any
33
44from a2a .client .client import (
55 Client ,
66 ClientCallContext ,
77 ClientConfig ,
8- ClientEvent ,
98 Consumer ,
9+ ClientEvent ,
1010)
1111from a2a .client .client_task_manager import ClientTaskManager
1212from a2a .client .errors import A2AClientInvalidStateError
1313from a2a .client .middleware import ClientCallInterceptor
1414from a2a .client .transports .base import ClientTransport
15- from a2a .types import (
15+ from a2a .types . a2a_pb2 import (
1616 AgentCard ,
17- GetTaskPushNotificationConfigParams ,
1817 Message ,
19- MessageSendConfiguration ,
20- MessageSendParams ,
18+ SendMessageConfiguration ,
19+ SendMessageRequest ,
2120 Task ,
2221 TaskArtifactUpdateEvent ,
23- TaskIdParams ,
22+ SubscribeToTaskRequest ,
23+ CancelTaskRequest ,
2424 TaskPushNotificationConfig ,
25- TaskQueryParams ,
25+ GetTaskRequest ,
2626 TaskStatusUpdateEvent ,
27+ StreamResponse ,
28+ SetTaskPushNotificationConfigRequest ,
29+ GetExtendedAgentCardRequest ,
30+ GetTaskPushNotificationConfigRequest ,
2731)
2832
2933
@@ -50,7 +54,7 @@ async def send_message(
5054 context : ClientCallContext | None = None ,
5155 request_metadata : dict [str , Any ] | None = None ,
5256 extensions : list [str ] | None = None ,
53- ) -> AsyncIterator [ClientEvent | Message ]:
57+ ) -> AsyncIterator [ClientEvent ]:
5458 """Sends a message to the agent.
5559
5660 This method handles both streaming and non-streaming (polling) interactions
@@ -64,9 +68,9 @@ async def send_message(
6468 extensions: List of extensions to be activated.
6569
6670 Yields:
67- An async iterator of `ClientEvent` or a final `Message` response.
71+ An async iterator of `ClientEvent`
6872 """
69- config = MessageSendConfiguration (
73+ config = SendMessageConfiguration (
7074 accepted_output_modes = self ._config .accepted_output_modes ,
7175 blocking = not self ._config .polling ,
7276 push_notification_config = (
@@ -75,67 +79,67 @@ async def send_message(
7579 else None
7680 ),
7781 )
78- params = MessageSendParams (
79- message = request , configuration = config , metadata = request_metadata
82+ sendMessageRequest = SendMessageRequest (
83+ request = request , configuration = config , metadata = request_metadata
8084 )
8185
8286 if not self ._config .streaming or not self ._card .capabilities .streaming :
8387 response = await self ._transport .send_message (
84- params , context = context , extensions = extensions
85- )
86- result = (
87- (response , None ) if isinstance (response , Task ) else response
88+ sendMessageRequest , context = context , extensions = extensions
8889 )
89- await self .consume (result , self ._card )
90- yield result
91- return
9290
93- tracker = ClientTaskManager ()
94- stream = self ._transport .send_message_streaming (
95- params , context = context , extensions = extensions
96- )
91+ # In non-streaming case we convert to a StreamResponse so that the
92+ # client always sees the same iterator.
93+ stream_response = StreamResponse ()
94+ client_event : ClientEvent
95+ if response .HasField ("task" ):
96+ stream_response .task = response .task
97+ client_event = (stream_response , response .task )
9798
98- first_event = await anext (stream )
99- # The response from a server may be either exactly one Message or a
100- # series of Task updates. Separate out the first message for special
101- # case handling, which allows us to simplify further stream processing.
102- if isinstance (first_event , Message ):
103- await self .consume (first_event , self ._card )
104- yield first_event
105- return
99+ elif response .HasField ("message" ):
100+ stream_response .msg = response .msg
101+ client_event = (stream_response , None )
106102
107- yield await self ._process_response (tracker , first_event )
103+ await self .consume (client_event , self ._card )
104+ yield client_event
105+ return
108106
109- async for event in stream :
110- yield await self ._process_response (tracker , event )
107+ stream = self ._transport .send_message_streaming (
108+ sendMessageRequest , context = context , extensions = extensions
109+ )
110+ async for client_event in self ._process_stream (stream ):
111+ yield client_event
111112
112- async def _process_response (
113- self ,
114- tracker : ClientTaskManager ,
115- event : Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ,
116- ) -> ClientEvent :
117- if isinstance (event , Message ):
118- raise A2AClientInvalidStateError (
119- 'received a streamed Message from server after first response; this is not supported'
120- )
121- await tracker .process (event )
122- task = tracker .get_task_or_raise ()
123- update = None if isinstance (event , Task ) else event
124- client_event = (task , update )
125- await self .consume (client_event , self ._card )
126- return client_event
113+ async def _process_stream (self , stream : AsyncIterator [StreamResponse ]) -> AsyncGenerator [ClientEvent ]:
114+ tracker = ClientTaskManager ()
115+ async for stream_response in stream :
116+ client_event : ClientEvent
117+ # When we get a message in the stream then we don't expect any
118+ # further messages so yield and return
119+ if stream_response .HasField ("message" ):
120+ client_event = (stream_response , None )
121+ await self .consume (client_event , self ._card )
122+ yield client_event
123+ return
124+
125+ # Otherwise track the task / task update then yield to the client
126+ await tracker .process (stream_response )
127+ updated_task = tracker .get_task_or_raise ()
128+ client_event = (stream_response , updated_task )
129+ await self .consume (client_event , self ._card )
130+ yield client_event
127131
128132 async def get_task (
129133 self ,
130- request : TaskQueryParams ,
134+ request : GetTaskRequest ,
131135 * ,
132136 context : ClientCallContext | None = None ,
133137 extensions : list [str ] | None = None ,
134138 ) -> Task :
135139 """Retrieves the current state and history of a specific task.
136140
137141 Args:
138- request: The `TaskQueryParams ` object specifying the task ID.
142+ request: The `GetTaskRequest ` object specifying the task ID.
139143 context: The client call context.
140144 extensions: List of extensions to be activated.
141145
@@ -148,15 +152,15 @@ async def get_task(
148152
149153 async def cancel_task (
150154 self ,
151- request : TaskIdParams ,
155+ request : CancelTaskRequest ,
152156 * ,
153157 context : ClientCallContext | None = None ,
154158 extensions : list [str ] | None = None ,
155159 ) -> Task :
156160 """Requests the agent to cancel a specific task.
157161
158162 Args:
159- request: The `TaskIdParams ` object specifying the task ID.
163+ request: The `CancelTaskRequest ` object specifying the task ID.
160164 context: The client call context.
161165 extensions: List of extensions to be activated.
162166
@@ -169,7 +173,7 @@ async def cancel_task(
169173
170174 async def set_task_callback (
171175 self ,
172- request : TaskPushNotificationConfig ,
176+ request : SetTaskPushNotificationConfigRequest ,
173177 * ,
174178 context : ClientCallContext | None = None ,
175179 extensions : list [str ] | None = None ,
@@ -190,7 +194,7 @@ async def set_task_callback(
190194
191195 async def get_task_callback (
192196 self ,
193- request : GetTaskPushNotificationConfigParams ,
197+ request : GetTaskPushNotificationConfigRequest ,
194198 * ,
195199 context : ClientCallContext | None = None ,
196200 extensions : list [str ] | None = None ,
@@ -209,9 +213,9 @@ async def get_task_callback(
209213 request , context = context , extensions = extensions
210214 )
211215
212- async def resubscribe (
216+ async def subscribe (
213217 self ,
214- request : TaskIdParams ,
218+ request : SubscribeToTaskRequest ,
215219 * ,
216220 context : ClientCallContext | None = None ,
217221 extensions : list [str ] | None = None ,
@@ -240,12 +244,13 @@ async def resubscribe(
240244 # Note: resubscribe can only be called on an existing task. As such,
241245 # we should never see Message updates, despite the typing of the service
242246 # definition indicating it may be possible.
243- async for event in self ._transport .resubscribe (
247+ stream = self ._transport .subscribe (
244248 request , context = context , extensions = extensions
245- ):
246- yield await self ._process_response (tracker , event )
249+ )
250+ async for client_event in self ._process_stream (stream ):
251+ yield client_event
247252
248- async def get_card (
253+ async def get_extended_agent_card (
249254 self ,
250255 * ,
251256 context : ClientCallContext | None = None ,
@@ -263,7 +268,7 @@ async def get_card(
263268 Returns:
264269 The `AgentCard` for the agent.
265270 """
266- card = await self ._transport .get_card (
271+ card = await self ._transport .get_extended_agent_card (
267272 context = context , extensions = extensions
268273 )
269274 self ._card = card
0 commit comments