1818)
1919from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
2020from a2a .client .transports .base import ClientTransport
21- from a2a .extensions . common import HTTP_EXTENSION_HEADER
21+ from a2a .client . transports . utils import get_http_args , update_extension_header
2222from a2a .types import (
2323 AgentCard ,
2424 CancelTaskRequest ,
@@ -83,25 +83,6 @@ def __init__(
8383 )
8484 self .extensions = extensions
8585
86- def _update_extension_header (
87- self , http_kwargs : dict [str , Any ]
88- ) -> dict [str , Any ]:
89- if not self .extensions :
90- return http_kwargs
91-
92- headers = http_kwargs .setdefault ('headers' , {})
93- existing_extensions_str = headers .get (HTTP_EXTENSION_HEADER , '' )
94-
95- existing_extensions = [
96- e .strip () for e in existing_extensions_str .split (',' ) if e .strip ()
97- ]
98-
99- all_extensions = set (self .extensions )
100- all_extensions .update (existing_extensions )
101-
102- headers [HTTP_EXTENSION_HEADER ] = ',' .join (list (all_extensions ))
103- return http_kwargs
104-
10586 async def _apply_interceptors (
10687 self ,
10788 method_name : str ,
@@ -125,11 +106,6 @@ async def _apply_interceptors(
125106 )
126107 return final_request_payload , final_http_kwargs
127108
128- def _get_http_args (
129- self , context : ClientCallContext | None
130- ) -> dict [str , Any ] | None :
131- return context .state .get ('http_kwargs' ) if context else None
132-
133109 async def send_message (
134110 self ,
135111 request : MessageSendParams ,
@@ -141,10 +117,12 @@ async def send_message(
141117 payload , modified_kwargs = await self ._apply_interceptors (
142118 'message/send' ,
143119 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
144- self . _get_http_args (context ),
120+ get_http_args (context ),
145121 context ,
146122 )
147- modified_kwargs = self ._update_extension_header (modified_kwargs )
123+ modified_kwargs = update_extension_header (
124+ modified_kwargs , self .extensions
125+ )
148126 response_data = await self ._send_request (payload , modified_kwargs )
149127 response = SendMessageResponse .model_validate (response_data )
150128 if isinstance (response .root , JSONRPCErrorResponse ):
@@ -166,11 +144,13 @@ async def send_message_streaming(
166144 payload , modified_kwargs = await self ._apply_interceptors (
167145 'message/stream' ,
168146 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
169- self . _get_http_args (context ),
147+ get_http_args (context ),
170148 context ,
171149 )
172150
173- modified_kwargs = self ._update_extension_header (modified_kwargs )
151+ modified_kwargs = update_extension_header (
152+ modified_kwargs , self .extensions
153+ )
174154 modified_kwargs .setdefault (
175155 'timeout' , self .httpx_client .timeout .as_dict ().get ('read' , None )
176156 )
@@ -237,7 +217,7 @@ async def get_task(
237217 payload , modified_kwargs = await self ._apply_interceptors (
238218 'tasks/get' ,
239219 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
240- self . _get_http_args (context ),
220+ get_http_args (context ),
241221 context ,
242222 )
243223 response_data = await self ._send_request (payload , modified_kwargs )
@@ -257,7 +237,7 @@ async def cancel_task(
257237 payload , modified_kwargs = await self ._apply_interceptors (
258238 'tasks/cancel' ,
259239 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
260- self . _get_http_args (context ),
240+ get_http_args (context ),
261241 context ,
262242 )
263243 response_data = await self ._send_request (payload , modified_kwargs )
@@ -279,7 +259,7 @@ async def set_task_callback(
279259 payload , modified_kwargs = await self ._apply_interceptors (
280260 'tasks/pushNotificationConfig/set' ,
281261 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
282- self . _get_http_args (context ),
262+ get_http_args (context ),
283263 context ,
284264 )
285265 response_data = await self ._send_request (payload , modified_kwargs )
@@ -303,7 +283,7 @@ async def get_task_callback(
303283 payload , modified_kwargs = await self ._apply_interceptors (
304284 'tasks/pushNotificationConfig/get' ,
305285 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
306- self . _get_http_args (context ),
286+ get_http_args (context ),
307287 context ,
308288 )
309289 response_data = await self ._send_request (payload , modified_kwargs )
@@ -327,7 +307,7 @@ async def resubscribe(
327307 payload , modified_kwargs = await self ._apply_interceptors (
328308 'tasks/resubscribe' ,
329309 rpc_request .model_dump (mode = 'json' , exclude_none = True ),
330- self . _get_http_args (context ),
310+ get_http_args (context ),
331311 context ,
332312 )
333313
@@ -369,7 +349,7 @@ async def get_card(
369349 if not card :
370350 resolver = A2ACardResolver (self .httpx_client , self .url )
371351 card = await resolver .get_agent_card (
372- http_kwargs = self . _get_http_args (context )
352+ http_kwargs = get_http_args (context )
373353 )
374354 self ._needs_extended_card = (
375355 card .supports_authenticated_extended_card
@@ -383,7 +363,7 @@ async def get_card(
383363 payload , modified_kwargs = await self ._apply_interceptors (
384364 request .method ,
385365 request .model_dump (mode = 'json' , exclude_none = True ),
386- self . _get_http_args (context ),
366+ get_http_args (context ),
387367 context ,
388368 )
389369
0 commit comments