Skip to content

Commit a9aa9ee

Browse files
committed
feat: add support for extensions in Client and BaseClient, update transport methods to handle extensions
1 parent 270d6e7 commit a9aa9ee

File tree

7 files changed

+93
-44
lines changed

7 files changed

+93
-44
lines changed

src/a2a/client/base_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ def __init__(
3636
transport: ClientTransport,
3737
consumers: list[Consumer],
3838
middleware: list[ClientCallInterceptor],
39+
extensions: list[str],
3940
):
40-
super().__init__(consumers, middleware)
41+
super().__init__(consumers, middleware, extensions)
4142
self._card = card
4243
self._config = config
4344
self._transport = transport
45+
if self._extensions:
46+
self._config.extensions = self._extensions
4447

4548
async def send_message(
4649
self,

src/a2a/client/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def __init__(
9393
self,
9494
consumers: list[Consumer] | None = None,
9595
middleware: list[ClientCallInterceptor] | None = None,
96-
# iva todo add optional extensions- it can override value from the config, if it is provided
96+
# iva todo - it can override value from the config, if it is provided
97+
extensions: list[str] | None = None,
9798
):
9899
"""Initializes the client with consumers and middleware.
99100
@@ -105,8 +106,11 @@ def __init__(
105106
middleware = []
106107
if consumers is None:
107108
consumers = []
109+
if extensions is None:
110+
extensions = []
108111
self._consumers = consumers
109112
self._middleware = middleware
113+
self._extensions = extensions
110114

111115
@abstractmethod
112116
async def send_message(

src/a2a/client/client_factory.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ async def connect( # noqa: PLR0913
115115
relative_card_path: str | None = None,
116116
resolver_http_kwargs: dict[str, Any] | None = None,
117117
extra_transports: dict[str, TransportProducer] | None = None,
118+
extensions: list[str] | None = None,
118119
) -> Client:
119120
"""Convenience method for constructing a client.
120121
@@ -168,7 +169,7 @@ async def connect( # noqa: PLR0913
168169
factory = cls(client_config)
169170
for label, generator in (extra_transports or {}).items():
170171
factory.register(label, generator)
171-
return factory.create(card, consumers, interceptors)
172+
return factory.create(card, consumers, interceptors, extensions)
172173

173174
def register(self, label: str, generator: TransportProducer) -> None:
174175
"""Register a new transport producer for a given transport label."""
@@ -179,6 +180,7 @@ def create(
179180
card: AgentCard,
180181
consumers: list[Consumer] | None = None,
181182
interceptors: list[ClientCallInterceptor] | None = None,
183+
extensions: list[str] | None = None,
182184
) -> Client:
183185
"""Create a new `Client` for the provided `AgentCard`.
184186
@@ -228,12 +230,22 @@ def create(
228230
if consumers:
229231
all_consumers.extend(consumers)
230232

233+
all_extensions = self._config.extensions.copy()
234+
if extensions:
235+
all_extensions.extend(extensions)
236+
self._config.extensions = all_extensions
237+
231238
transport = self._registry[transport_protocol](
232239
card, transport_url, self._config, interceptors or []
233240
)
234241

235242
return BaseClient(
236-
card, self._config, transport, all_consumers, interceptors or []
243+
card,
244+
self._config,
245+
transport,
246+
all_consumers,
247+
interceptors or [],
248+
all_extensions,
237249
)
238250

239251

src/a2a/client/transports/jsonrpc.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ async def get_task(
220220
get_http_args(context),
221221
context,
222222
)
223+
modified_kwargs = update_extension_header(
224+
modified_kwargs, self.extensions
225+
)
223226
response_data = await self._send_request(payload, modified_kwargs)
224227
response = GetTaskResponse.model_validate(response_data)
225228
if isinstance(response.root, JSONRPCErrorResponse):
@@ -240,6 +243,9 @@ async def cancel_task(
240243
get_http_args(context),
241244
context,
242245
)
246+
modified_kwargs = update_extension_header(
247+
modified_kwargs, self.extensions
248+
)
243249
response_data = await self._send_request(payload, modified_kwargs)
244250
response = CancelTaskResponse.model_validate(response_data)
245251
if isinstance(response.root, JSONRPCErrorResponse):
@@ -262,6 +268,9 @@ async def set_task_callback(
262268
get_http_args(context),
263269
context,
264270
)
271+
modified_kwargs = update_extension_header(
272+
modified_kwargs, self.extensions
273+
)
265274
response_data = await self._send_request(payload, modified_kwargs)
266275
response = SetTaskPushNotificationConfigResponse.model_validate(
267276
response_data
@@ -286,6 +295,9 @@ async def get_task_callback(
286295
get_http_args(context),
287296
context,
288297
)
298+
modified_kwargs = update_extension_header(
299+
modified_kwargs, self.extensions
300+
)
289301
response_data = await self._send_request(payload, modified_kwargs)
290302
response = GetTaskPushNotificationConfigResponse.model_validate(
291303
response_data
@@ -310,7 +322,9 @@ async def resubscribe(
310322
get_http_args(context),
311323
context,
312324
)
313-
325+
modified_kwargs = update_extension_header(
326+
modified_kwargs, self.extensions
327+
)
314328
modified_kwargs.setdefault('timeout', None)
315329

316330
async with aconnect_sse(
@@ -366,6 +380,9 @@ async def get_card(
366380
get_http_args(context),
367381
context,
368382
)
383+
modified_kwargs = update_extension_header(
384+
modified_kwargs, self.extensions
385+
)
369386

370387
response_data = await self._send_request(
371388
payload,

src/a2a/client/transports/rest.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ async def get_task(
212212
get_http_args(context),
213213
context,
214214
)
215+
modified_kwargs = update_extension_header(
216+
modified_kwargs, self.extensions
217+
)
215218
response_data = await self._send_get_request(
216219
f'/v1/tasks/{request.id}',
217220
{'historyLength': str(request.history_length)}
@@ -237,6 +240,9 @@ async def cancel_task(
237240
get_http_args(context),
238241
context,
239242
)
243+
modified_kwargs = update_extension_header(
244+
modified_kwargs, self.extensions
245+
)
240246
response_data = await self._send_post_request(
241247
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
242248
)
@@ -260,6 +266,9 @@ async def set_task_callback(
260266
payload, modified_kwargs = await self._apply_interceptors(
261267
payload, get_http_args(context), context
262268
)
269+
modified_kwargs = update_extension_header(
270+
modified_kwargs, self.extensions
271+
)
263272
response_data = await self._send_post_request(
264273
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
265274
payload,
@@ -285,6 +294,9 @@ async def get_task_callback(
285294
get_http_args(context),
286295
context,
287296
)
297+
modified_kwargs = update_extension_header(
298+
modified_kwargs, self.extensions
299+
)
288300
response_data = await self._send_get_request(
289301
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
290302
{},
@@ -305,12 +317,13 @@ async def resubscribe(
305317
"""Reconnects to get task updates."""
306318
http_kwargs = get_http_args(context) or {}
307319
http_kwargs.setdefault('timeout', None)
320+
modified_kwargs = update_extension_header(http_kwargs, self.extensions)
308321

309322
async with aconnect_sse(
310323
self.httpx_client,
311324
'GET',
312325
f'{self.url}/v1/tasks/{request.id}:subscribe',
313-
**http_kwargs,
326+
**modified_kwargs,
314327
) as event_source:
315328
try:
316329
async for sse in event_source.aiter_sse():
@@ -353,6 +366,9 @@ async def get_card(
353366
get_http_args(context),
354367
context,
355368
)
369+
modified_kwargs = update_extension_header(
370+
modified_kwargs, self.extensions
371+
)
356372
response_data = await self._send_get_request(
357373
'/v1/card', {}, modified_kwargs
358374
)

tests/client/test_base_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def base_client(sample_agent_card, mock_transport):
5555
transport=mock_transport,
5656
consumers=[],
5757
middleware=[],
58+
extensions=[],
5859
)
5960

6061

tests/client/transports/test_utils.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,70 +5,66 @@
55

66

77
class TestUtils:
8-
def test_update_extension_header_no_initial_headers(self):
9-
extensions = ['test_extension_1', 'test_extension_2']
10-
11-
http_kwargs = {}
12-
result_kwargs = update_extension_header(http_kwargs, extensions)
13-
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
14-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
15-
actual_extensions = set(actual_extensions_list)
16-
17-
expected_extensions = {
18-
'test_extension_1',
19-
'test_extension_2',
20-
}
21-
assert len(actual_extensions_list) == 2
22-
assert actual_extensions == expected_extensions
23-
248
@pytest.mark.parametrize(
25-
'existing_header, expected_count',
9+
'extensions, existing_header, expected_extensions, expected_count',
2610
[
27-
('test_extension_1, test_extension_2', 3),
28-
('test_extension_1,test_extension_2', 3),
29-
('test_extension_1', 3),
11+
(
12+
['test_extension_1', 'test_extension_2'],
13+
'',
14+
{
15+
'test_extension_1',
16+
'test_extension_2',
17+
},
18+
2,
19+
),
20+
(
21+
['test_extension_1', 'test_extension_2'],
22+
'test_extension_2, test_extension_3',
23+
{
24+
'test_extension_1',
25+
'test_extension_2',
26+
'test_extension_3',
27+
},
28+
3,
29+
),
30+
(
31+
['test_extension_1', 'test_extension_2'],
32+
'test_extension_3',
33+
{
34+
'test_extension_1',
35+
'test_extension_2',
36+
'test_extension_3',
37+
},
38+
3,
39+
),
3040
],
3141
)
3242
def test_update_extension_header_merge_with_existing_extensions(
3343
self,
44+
extensions: list[str],
3445
existing_header: str,
46+
expected_extensions: set[str],
3547
expected_count: int,
3648
):
37-
extensions = ['test_extension_2', 'test_extension_3']
3849
http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}}
3950
result_kwargs = update_extension_header(http_kwargs, extensions)
4051

4152
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
4253
actual_extensions_list = [e.strip() for e in header_value.split(',')]
4354
actual_extensions = set(actual_extensions_list)
4455

45-
expected_extensions = {
46-
'test_extension_1',
47-
'test_extension_2',
48-
'test_extension_3',
49-
}
5056
assert len(actual_extensions_list) == expected_count
5157
assert actual_extensions == expected_extensions
5258

5359
def test_update_extension_header_with_other_headers(self):
54-
extensions = ['test_extension_1']
60+
extensions = ['test_extension']
5561
http_kwargs = {'headers': {'X_Other': 'Test'}}
5662
result_kwargs = update_extension_header(http_kwargs, extensions)
5763
headers = result_kwargs.get('headers', {})
5864
assert HTTP_EXTENSION_HEADER in headers
59-
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension_1'
65+
assert headers[HTTP_EXTENSION_HEADER] == 'test_extension'
6066
assert headers['X_Other'] == 'Test'
6167

62-
def test_update_extension_header_with_existing_other_headers(self):
63-
extensions = ['test_extension_1']
64-
http_kwargs = {'headers': {'X_Other': 'Test'}}
65-
result_kwargs = update_extension_header(http_kwargs, extensions)
66-
assert (
67-
result_kwargs['headers'][HTTP_EXTENSION_HEADER]
68-
== 'test_extension_1'
69-
)
70-
assert result_kwargs['headers']['X_Other'] == 'Test'
71-
7268
def test_update_extension_header_no_extensions(self):
7369
http_kwargs = {'headers': {'X_Other': 'Test'}}
7470
result_kwargs = update_extension_header(http_kwargs, None)

0 commit comments

Comments
 (0)