Skip to content

Commit 4a423ef

Browse files
committed
Change the order of update_extension_header and _apply_interceptors function calls inside rest and jsonrpc methods
1 parent 80be4bf commit 4a423ef

File tree

4 files changed

+93
-86
lines changed

4 files changed

+93
-86
lines changed

src/a2a/client/transports/jsonrpc.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,15 @@ async def send_message(
120120
) -> Task | Message:
121121
"""Sends a non-streaming message request to the agent."""
122122
rpc_request = SendMessageRequest(params=request, id=str(uuid4()))
123+
modified_kwargs = update_extension_header(
124+
self._get_http_args(context),
125+
extensions if extensions is not None else self.extensions,
126+
)
123127
payload, modified_kwargs = await self._apply_interceptors(
124128
'message/send',
125129
rpc_request.model_dump(mode='json', exclude_none=True),
126-
self._get_http_args(context),
127-
context,
128-
)
129-
modified_kwargs = update_extension_header(
130130
modified_kwargs,
131-
extensions if extensions is not None else self.extensions,
131+
context,
132132
)
133133
response_data = await self._send_request(payload, modified_kwargs)
134134
response = SendMessageResponse.model_validate(response_data)
@@ -149,16 +149,15 @@ async def send_message_streaming(
149149
rpc_request = SendStreamingMessageRequest(
150150
params=request, id=str(uuid4())
151151
)
152+
modified_kwargs = update_extension_header(
153+
self._get_http_args(context),
154+
extensions if extensions is not None else self.extensions,
155+
)
152156
payload, modified_kwargs = await self._apply_interceptors(
153157
'message/stream',
154158
rpc_request.model_dump(mode='json', exclude_none=True),
155-
self._get_http_args(context),
156-
context,
157-
)
158-
159-
modified_kwargs = update_extension_header(
160159
modified_kwargs,
161-
extensions if extensions is not None else self.extensions,
160+
context,
162161
)
163162
modified_kwargs.setdefault(
164163
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
@@ -224,15 +223,15 @@ async def get_task(
224223
) -> Task:
225224
"""Retrieves the current state and history of a specific task."""
226225
rpc_request = GetTaskRequest(params=request, id=str(uuid4()))
226+
modified_kwargs = update_extension_header(
227+
self._get_http_args(context),
228+
extensions if extensions is not None else self.extensions,
229+
)
227230
payload, modified_kwargs = await self._apply_interceptors(
228231
'tasks/get',
229232
rpc_request.model_dump(mode='json', exclude_none=True),
230-
self._get_http_args(context),
231-
context,
232-
)
233-
modified_kwargs = update_extension_header(
234233
modified_kwargs,
235-
extensions if extensions is not None else self.extensions,
234+
context,
236235
)
237236
response_data = await self._send_request(payload, modified_kwargs)
238237
response = GetTaskResponse.model_validate(response_data)
@@ -249,15 +248,15 @@ async def cancel_task(
249248
) -> Task:
250249
"""Requests the agent to cancel a specific task."""
251250
rpc_request = CancelTaskRequest(params=request, id=str(uuid4()))
251+
modified_kwargs = update_extension_header(
252+
self._get_http_args(context),
253+
extensions if extensions is not None else self.extensions,
254+
)
252255
payload, modified_kwargs = await self._apply_interceptors(
253256
'tasks/cancel',
254257
rpc_request.model_dump(mode='json', exclude_none=True),
255-
self._get_http_args(context),
256-
context,
257-
)
258-
modified_kwargs = update_extension_header(
259258
modified_kwargs,
260-
extensions if extensions is not None else self.extensions,
259+
context,
261260
)
262261
response_data = await self._send_request(payload, modified_kwargs)
263262
response = CancelTaskResponse.model_validate(response_data)
@@ -276,15 +275,15 @@ async def set_task_callback(
276275
rpc_request = SetTaskPushNotificationConfigRequest(
277276
params=request, id=str(uuid4())
278277
)
278+
modified_kwargs = update_extension_header(
279+
self._get_http_args(context),
280+
extensions if extensions is not None else self.extensions,
281+
)
279282
payload, modified_kwargs = await self._apply_interceptors(
280283
'tasks/pushNotificationConfig/set',
281284
rpc_request.model_dump(mode='json', exclude_none=True),
282-
self._get_http_args(context),
283-
context,
284-
)
285-
modified_kwargs = update_extension_header(
286285
modified_kwargs,
287-
extensions if extensions is not None else self.extensions,
286+
context,
288287
)
289288
response_data = await self._send_request(payload, modified_kwargs)
290289
response = SetTaskPushNotificationConfigResponse.model_validate(
@@ -305,15 +304,15 @@ async def get_task_callback(
305304
rpc_request = GetTaskPushNotificationConfigRequest(
306305
params=request, id=str(uuid4())
307306
)
307+
modified_kwargs = update_extension_header(
308+
self._get_http_args(context),
309+
extensions if extensions is not None else self.extensions,
310+
)
308311
payload, modified_kwargs = await self._apply_interceptors(
309312
'tasks/pushNotificationConfig/get',
310313
rpc_request.model_dump(mode='json', exclude_none=True),
311-
self._get_http_args(context),
312-
context,
313-
)
314-
modified_kwargs = update_extension_header(
315314
modified_kwargs,
316-
extensions if extensions is not None else self.extensions,
315+
context,
317316
)
318317
response_data = await self._send_request(payload, modified_kwargs)
319318
response = GetTaskPushNotificationConfigResponse.model_validate(
@@ -334,15 +333,15 @@ async def resubscribe(
334333
]:
335334
"""Reconnects to get task updates."""
336335
rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4()))
336+
modified_kwargs = update_extension_header(
337+
self._get_http_args(context),
338+
extensions if extensions is not None else self.extensions,
339+
)
337340
payload, modified_kwargs = await self._apply_interceptors(
338341
'tasks/resubscribe',
339342
rpc_request.model_dump(mode='json', exclude_none=True),
340-
self._get_http_args(context),
341-
context,
342-
)
343-
modified_kwargs = update_extension_header(
344343
modified_kwargs,
345-
extensions if extensions is not None else self.extensions,
344+
context,
346345
)
347346
modified_kwargs.setdefault('timeout', None)
348347

@@ -394,17 +393,16 @@ async def get_card(
394393
return card
395394

396395
request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))
396+
modified_kwargs = update_extension_header(
397+
self._get_http_args(context),
398+
extensions if extensions is not None else self.extensions,
399+
)
397400
payload, modified_kwargs = await self._apply_interceptors(
398401
request.method,
399402
request.model_dump(mode='json', exclude_none=True),
400-
self._get_http_args(context),
401-
context,
402-
)
403-
modified_kwargs = update_extension_header(
404403
modified_kwargs,
405-
extensions if extensions is not None else self.extensions,
404+
context,
406405
)
407-
408406
response_data = await self._send_request(
409407
payload,
410408
modified_kwargs,

src/a2a/client/transports/rest.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ async def _prepare_send_message(
9999
),
100100
)
101101
payload = MessageToDict(pb)
102-
payload, modified_kwargs = await self._apply_interceptors(
103-
payload,
102+
modified_kwargs = update_extension_header(
104103
self._get_http_args(context),
105-
context,
104+
extensions if extensions is not None else self.extensions,
106105
)
107-
modified_kwargs = update_extension_header(
106+
payload, modified_kwargs = await self._apply_interceptors(
107+
payload,
108108
modified_kwargs,
109-
extensions if extensions is not None else self.extensions,
109+
context,
110110
)
111111
return payload, modified_kwargs
112112

@@ -219,14 +219,14 @@ async def get_task(
219219
extensions: list[str] | None = None,
220220
) -> Task:
221221
"""Retrieves the current state and history of a specific task."""
222-
_payload, modified_kwargs = await self._apply_interceptors(
223-
request.model_dump(mode='json', exclude_none=True),
222+
modified_kwargs = update_extension_header(
224223
self._get_http_args(context),
225-
context,
224+
extensions if extensions is not None else self.extensions,
226225
)
227-
modified_kwargs = update_extension_header(
226+
_payload, modified_kwargs = await self._apply_interceptors(
227+
request.model_dump(mode='json', exclude_none=True),
228228
modified_kwargs,
229-
extensions if extensions is not None else self.extensions,
229+
context,
230230
)
231231
response_data = await self._send_get_request(
232232
f'/v1/tasks/{request.id}',
@@ -249,14 +249,14 @@ async def cancel_task(
249249
"""Requests the agent to cancel a specific task."""
250250
pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
251251
payload = MessageToDict(pb)
252-
payload, modified_kwargs = await self._apply_interceptors(
253-
payload,
252+
modified_kwargs = update_extension_header(
254253
self._get_http_args(context),
255-
context,
254+
extensions if extensions is not None else self.extensions,
256255
)
257-
modified_kwargs = update_extension_header(
256+
payload, modified_kwargs = await self._apply_interceptors(
257+
payload,
258258
modified_kwargs,
259-
extensions if extensions is not None else self.extensions,
259+
context,
260260
)
261261
response_data = await self._send_post_request(
262262
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
@@ -279,13 +279,13 @@ async def set_task_callback(
279279
config=proto_utils.ToProto.task_push_notification_config(request),
280280
)
281281
payload = MessageToDict(pb)
282-
payload, modified_kwargs = await self._apply_interceptors(
283-
payload, self._get_http_args(context), context
284-
)
285282
modified_kwargs = update_extension_header(
286-
modified_kwargs,
283+
self._get_http_args(context),
287284
extensions if extensions is not None else self.extensions,
288285
)
286+
payload, modified_kwargs = await self._apply_interceptors(
287+
payload, modified_kwargs, context
288+
)
289289
response_data = await self._send_post_request(
290290
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
291291
payload,
@@ -307,14 +307,14 @@ async def get_task_callback(
307307
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
308308
)
309309
payload = MessageToDict(pb)
310-
payload, modified_kwargs = await self._apply_interceptors(
311-
payload,
310+
modified_kwargs = update_extension_header(
312311
self._get_http_args(context),
313-
context,
312+
extensions if extensions is not None else self.extensions,
314313
)
315-
modified_kwargs = update_extension_header(
314+
payload, modified_kwargs = await self._apply_interceptors(
315+
payload,
316316
modified_kwargs,
317-
extensions if extensions is not None else self.extensions,
317+
context,
318318
)
319319
response_data = await self._send_get_request(
320320
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
@@ -335,12 +335,11 @@ async def resubscribe(
335335
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
336336
]:
337337
"""Reconnects to get task updates."""
338-
http_kwargs = self._get_http_args(context) or {}
339-
http_kwargs.setdefault('timeout', None)
340338
modified_kwargs = update_extension_header(
341-
http_kwargs,
339+
self._get_http_args(context),
342340
extensions if extensions is not None else self.extensions,
343341
)
342+
modified_kwargs.setdefault('timeout', None)
344343

345344
async with aconnect_sse(
346345
self.httpx_client,
@@ -385,14 +384,14 @@ async def get_card(
385384
if not self._needs_extended_card:
386385
return card
387386

388-
_, modified_kwargs = await self._apply_interceptors(
389-
{},
387+
modified_kwargs = update_extension_header(
390388
self._get_http_args(context),
391-
context,
389+
extensions if extensions is not None else self.extensions,
392390
)
393-
modified_kwargs = update_extension_header(
391+
_, modified_kwargs = await self._apply_interceptors(
392+
{},
394393
modified_kwargs,
395-
extensions if extensions is not None else self.extensions,
394+
context,
396395
)
397396
response_data = await self._send_get_request(
398397
'/v1/card', {}, modified_kwargs

src/a2a/extensions/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
3030

3131

3232
def update_extension_header(
33-
http_kwargs: dict[str, Any],
33+
http_kwargs: dict[str, Any] | None,
3434
extensions: list[str] | None,
3535
) -> dict[str, Any]:
3636
"""Update the X-A2A-Extensions header with active extensions."""
37+
http_kwargs = http_kwargs or {}
3738
if extensions is not None:
3839
headers = http_kwargs.setdefault('headers', {})
3940
headers[HTTP_EXTENSION_HEADER] = ','.join(extensions)

tests/extensions/test_common.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,20 @@ def test_find_extension_by_uri_no_extensions():
7474
), # Case 1: New extensions provided, empty header.
7575
(
7676
None, # extensions
77-
'ext1, ext2', # existing_header
77+
'ext1, ext2', # header
7878
{
7979
'ext1',
8080
'ext2',
8181
}, # expected_extensions
8282
), # Case 2: Extensions is None, existing header extensions.
8383
(
8484
[], # extensions
85-
'ext1', # existing_header
85+
'ext1', # header
8686
{}, # expected_extensions
8787
), # Case 3: New extensions is empty list, existing header extensions.
8888
(
8989
['ext1', 'ext2'], # extensions
90-
'ext3', # existing_header
90+
'ext3', # header
9191
{
9292
'ext1',
9393
'ext2',
@@ -121,17 +121,26 @@ def test_update_extension_header_with_other_headers():
121121
assert headers['X_Other'] == 'Test'
122122

123123

124-
def test_update_extension_header_with_other_headers_extensions_none():
125-
http_kwargs = {'headers': {'X_Other': 'Test'}}
126-
result_kwargs = update_extension_header(http_kwargs, None)
127-
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
128-
assert result_kwargs['headers']['X_Other'] == 'Test'
129-
130-
131-
def test_update_extension_header_headers_not_in_kwargs():
124+
@pytest.mark.parametrize(
125+
'http_kwargs',
126+
[
127+
None,
128+
{},
129+
],
130+
)
131+
def test_update_extension_header_headers_not_in_kwargs(
132+
http_kwargs: dict[str, str] | None,
133+
):
132134
extensions = ['ext']
133135
http_kwargs = {}
134136
result_kwargs = update_extension_header(http_kwargs, extensions)
135137
headers = result_kwargs.get('headers', {})
136138
assert HTTP_EXTENSION_HEADER in headers
137139
assert headers[HTTP_EXTENSION_HEADER] == 'ext'
140+
141+
142+
def test_update_extension_header_with_other_headers_extensions_none():
143+
http_kwargs = {'headers': {'X_Other': 'Test'}}
144+
result_kwargs = update_extension_header(http_kwargs, None)
145+
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
146+
assert result_kwargs['headers']['X_Other'] == 'Test'

0 commit comments

Comments
 (0)