1- from collections .abc import AsyncIterator
1+ from collections .abc import AsyncIterator , AsyncGenerator
22from typing import Any
33
44from a2a .client .client import (
55 Client ,
66 ClientCallContext ,
77 ClientConfig ,
8- ClientEvent ,
98 Consumer ,
9+ ClientEvent ,
1010)
1111from a2a .client .client_task_manager import ClientTaskManager
1212from a2a .client .errors import A2AClientInvalidStateError
1313from a2a .client .middleware import ClientCallInterceptor
1414from a2a .client .transports .base import ClientTransport
15- from a2a .types import (
15+ from a2a .types . a2a_pb2 import (
1616 AgentCard ,
1717 GetTaskPushNotificationConfigParams ,
1818 Message ,
19- MessageSendConfiguration ,
20- MessageSendParams ,
19+ SendMessageConfiguration ,
20+ SendMessageRequest ,
2121 Task ,
2222 TaskArtifactUpdateEvent ,
23- TaskIdParams ,
23+ SubscribeToTaskRequest ,
24+ CancelTaskRequest ,
2425 TaskPushNotificationConfig ,
25- TaskQueryParams ,
26+ GetTaskRequest ,
2627 TaskStatusUpdateEvent ,
28+ StreamResponse ,
2729)
2830
2931
@@ -50,7 +52,7 @@ async def send_message(
5052 context : ClientCallContext | None = None ,
5153 request_metadata : dict [str , Any ] | None = None ,
5254 extensions : list [str ] | None = None ,
53- ) -> AsyncIterator [ClientEvent | Message ]:
55+ ) -> AsyncIterator [StreamResponse ]:
5456 """Sends a message to the agent.
5557
5658 This method handles both streaming and non-streaming (polling) interactions
@@ -64,9 +66,9 @@ async def send_message(
6466 extensions: List of extensions to be activated.
6567
6668 Yields:
67- An async iterator of `ClientEvent` or a final `Message` response.
69+ An async iterator of `ClientEvent`
6870 """
69- config = MessageSendConfiguration (
71+ config = SendMessageConfiguration (
7072 accepted_output_modes = self ._config .accepted_output_modes ,
7173 blocking = not self ._config .polling ,
7274 push_notification_config = (
@@ -75,67 +77,63 @@ async def send_message(
7577 else None
7678 ),
7779 )
78- params = MessageSendParams (
79- message = request , configuration = config , metadata = request_metadata
80+ params = SendMessageRequest (
81+ request = request , configuration = config , metadata = request_metadata
8082 )
8183
8284 if not self ._config .streaming or not self ._card .capabilities .streaming :
8385 response = await self ._transport .send_message (
8486 params , context = context , extensions = extensions
8587 )
86- result = (
87- (response , None ) if isinstance (response , Task ) else response
88- )
89- await self .consume (result , self ._card )
90- yield result
88+
89+ # In non-streaming case we convert to a StreamResponse so that the
90+ # client always sees the same iterator.
91+ stream_response = StreamResponse ()
92+ if response .HasField ("task" ):
93+ stream_response .task = response .task
94+ client_event = (stream_response , response .task )
95+
96+ if response .HasField ("message" ):
97+ stream_response .message = response .message
98+ client_event = (stream_response , None )
99+
100+ await self .consume (client_event , self ._card )
101+ yield client_event
91102 return
92103
93- tracker = ClientTaskManager ()
94104 stream = self ._transport .send_message_streaming (
95105 params , context = context , extensions = extensions
96106 )
107+ async for client_event in self ._process_stream (stream ):
108+ yield client_event
97109
98- first_event = await anext (stream )
99- # The response from a server may be either exactly one Message or a
100- # series of Task updates. Separate out the first message for special
101- # case handling, which allows us to simplify further stream processing.
102- if isinstance (first_event , Message ):
103- await self .consume (first_event , self ._card )
104- yield first_event
105- return
106-
107- yield await self ._process_response (tracker , first_event )
110+ async def _process_stream (self , stream : AsyncIterator [StreamResponse ]) -> AsyncGenerator [ClientEvent ]:
111+ tracker = ClientTaskManager ()
112+ async for stream_response in stream :
113+ await self .consume (stream_response )
108114
109- async for event in stream :
110- yield await self ._process_response (tracker , event )
115+ # When we get a message in the stream then we don't expect any
116+ # further messages so yield and return
117+ if stream_response .HasField ("message" ):
118+ yield stream_response , None
119+ return
111120
112- async def _process_response (
113- self ,
114- tracker : ClientTaskManager ,
115- event : Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ,
116- ) -> ClientEvent :
117- if isinstance (event , Message ):
118- raise A2AClientInvalidStateError (
119- 'received a streamed Message from server after first response; this is not supported'
120- )
121- await tracker .process (event )
122- task = tracker .get_task_or_raise ()
123- update = None if isinstance (event , Task ) else event
124- client_event = (task , update )
125- await self .consume (client_event , self ._card )
126- return client_event
121+ # Otherwise track the task / task update then yield to the client
122+ tracker .process (stream_response )
123+ updated_task = tracker .get_task_or_raise ()
124+ yield stream_response , updated_task
127125
128126 async def get_task (
129127 self ,
130- request : TaskQueryParams ,
128+ request : GetTaskRequest ,
131129 * ,
132130 context : ClientCallContext | None = None ,
133131 extensions : list [str ] | None = None ,
134132 ) -> Task :
135133 """Retrieves the current state and history of a specific task.
136134
137135 Args:
138- request: The `TaskQueryParams ` object specifying the task ID.
136+ request: The `GetTaskRequest ` object specifying the task ID.
139137 context: The client call context.
140138 extensions: List of extensions to be activated.
141139
@@ -148,15 +146,15 @@ async def get_task(
148146
149147 async def cancel_task (
150148 self ,
151- request : TaskIdParams ,
149+ request : CancelTaskRequest ,
152150 * ,
153151 context : ClientCallContext | None = None ,
154152 extensions : list [str ] | None = None ,
155153 ) -> Task :
156154 """Requests the agent to cancel a specific task.
157155
158156 Args:
159- request: The `TaskIdParams ` object specifying the task ID.
157+ request: The `CancelTaskRequest ` object specifying the task ID.
160158 context: The client call context.
161159 extensions: List of extensions to be activated.
162160
@@ -211,7 +209,7 @@ async def get_task_callback(
211209
212210 async def resubscribe (
213211 self ,
214- request : TaskIdParams ,
212+ request : SubscribeToTaskRequest ,
215213 * ,
216214 context : ClientCallContext | None = None ,
217215 extensions : list [str ] | None = None ,
@@ -240,10 +238,11 @@ async def resubscribe(
240238 # Note: resubscribe can only be called on an existing task. As such,
241239 # we should never see Message updates, despite the typing of the service
242240 # definition indicating it may be possible.
243- async for event in self ._transport .resubscribe (
241+ stream = self ._transport .resubscribe (
244242 request , context = context , extensions = extensions
245- ):
246- yield await self ._process_response (tracker , event )
243+ )
244+ async for client_event in self ._process_stream (stream ):
245+ yield client_event
247246
248247 async def get_card (
249248 self ,
0 commit comments