Skip to content

Commit 97eec52

Browse files
committed
refactor: streamline extension header handling in JsonRpcTransport and RestTransport tests.
Remove redundant code from client_factory.py
1 parent 3144f43 commit 97eec52

File tree

3 files changed

+93
-40
lines changed

3 files changed

+93
-40
lines changed

src/a2a/client/client_factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ def _register_defaults(
7676
self.register(
7777
TransportProtocol.jsonrpc,
7878
lambda card, url, config, interceptors: JsonRpcTransport(
79-
httpx_client=config.httpx_client or httpx.AsyncClient(),
80-
agent_card=card,
81-
url=url,
82-
interceptors=interceptors,
83-
client_extensions=config.extensions or None,
79+
config.httpx_client or httpx.AsyncClient(),
80+
card,
81+
url,
82+
interceptors,
83+
config.extensions or None,
8484
),
8585
)
8686
if TransportProtocol.http_json in supported:

tests/client/test_jsonrpc_client.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -800,10 +800,15 @@ def test_update_extension_header_no_initial_headers(
800800
)
801801
http_kwargs = {}
802802
result_kwargs = client._update_extension_header(http_kwargs)
803-
actual_extensions = set(
804-
result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ')
805-
)
806-
expected_extensions = {'test_extension_1', 'test_extension_2'}
803+
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
804+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
805+
actual_extensions = set(actual_extensions_list)
806+
807+
expected_extensions = {
808+
'test_extension_1',
809+
'test_extension_2',
810+
}
811+
assert len(actual_extensions_list) == 2
807812
assert actual_extensions == expected_extensions
808813

809814
def test_update_extension_header_with_existing_other_headers(
@@ -823,25 +828,34 @@ def test_update_extension_header_with_existing_other_headers(
823828
)
824829
assert result_kwargs['headers']['X_Other'] == 'Test'
825830

831+
@pytest.mark.parametrize(
832+
'existing_header, expected_count',
833+
[
834+
('test_extension_2, test_extension_3', 3),
835+
('test_extension_2,test_extension_3', 3),
836+
('test_extension_3', 3),
837+
],
838+
)
826839
def test_update_extension_header_merge_with_existing_extensions(
827-
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
840+
self,
841+
mock_httpx_client: AsyncMock,
842+
mock_agent_card: MagicMock,
843+
existing_header: str,
844+
expected_count: int,
828845
):
829846
extensions = ['test_extension_1', 'test_extension_2']
830847
client = JsonRpcTransport(
831848
httpx_client=mock_httpx_client,
832849
agent_card=mock_agent_card,
833850
client_extensions=extensions,
834851
)
835-
http_kwargs = {
836-
'headers': {
837-
HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3'
838-
}
839-
}
852+
http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}}
840853
result_kwargs = client._update_extension_header(http_kwargs)
841-
actual_extensions_list = result_kwargs['headers'][
842-
HTTP_EXTENSION_HEADER
843-
].split(', ')
854+
855+
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
856+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
844857
actual_extensions = set(actual_extensions_list)
858+
845859
expected_extensions = {
846860
'test_extension_1',
847861
'test_extension_2',
@@ -853,7 +867,11 @@ def test_update_extension_header_merge_with_existing_extensions(
853867
def test_update_extension_header_no_client_extensions(
854868
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
855869
):
856-
client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card)
870+
client = JsonRpcTransport(
871+
httpx_client=mock_httpx_client,
872+
agent_card=mock_agent_card,
873+
client_extensions=None,
874+
)
857875
http_kwargs = {'headers': {'X_Other': 'Test'}}
858876
result_kwargs = client._update_extension_header(http_kwargs)
859877
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
@@ -862,7 +880,11 @@ def test_update_extension_header_no_client_extensions(
862880
def test_update_extension_header_empty_client_extensions(
863881
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
864882
):
865-
client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card)
883+
client = JsonRpcTransport(
884+
httpx_client=mock_httpx_client,
885+
agent_card=mock_agent_card,
886+
client_extensions=[],
887+
)
866888
http_kwargs = {'headers': {'X_Other': 'Test'}}
867889
result_kwargs = client._update_extension_header(http_kwargs)
868890
assert HTTP_EXTENSION_HEADER not in result_kwargs['headers']
@@ -876,8 +898,8 @@ async def test_send_message_with_extensions(
876898
extensions = ['test_extension_1', 'test_extension_2']
877899
client = JsonRpcTransport(
878900
httpx_client=mock_httpx_client,
879-
client_extensions=extensions,
880901
agent_card=mock_agent_card,
902+
client_extensions=extensions,
881903
)
882904
params = MessageSendParams(
883905
message=create_text_message_object(content='Hello')
@@ -898,10 +920,18 @@ async def test_send_message_with_extensions(
898920

899921
mock_httpx_client.post.assert_called_once()
900922
_, mock_kwargs = mock_httpx_client.post.call_args
923+
901924
headers = mock_kwargs.get('headers', {})
902925
assert HTTP_EXTENSION_HEADER in headers
903-
actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', '))
904-
expected_extensions = {'test_extension_1', 'test_extension_2'}
926+
header_value = headers[HTTP_EXTENSION_HEADER]
927+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
928+
actual_extensions = set(actual_extensions_list)
929+
930+
expected_extensions = {
931+
'test_extension_1',
932+
'test_extension_2',
933+
}
934+
assert len(actual_extensions_list) == 2
905935
assert actual_extensions == expected_extensions
906936

907937
@pytest.mark.asyncio
@@ -916,8 +946,8 @@ async def test_send_message_streaming_with_extensions(
916946
extensions = ['test_extension']
917947
client = JsonRpcTransport(
918948
httpx_client=mock_httpx_client,
919-
client_extensions=extensions,
920949
agent_card=mock_agent_card,
950+
client_extensions=extensions,
921951
)
922952
params = MessageSendParams(
923953
message=create_text_message_object(content='Hello stream')

tests/client/test_rest_client.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,35 +44,51 @@ def test_update_extension_header_no_initial_headers(
4444
)
4545
http_kwargs = {}
4646
result_kwargs = client._update_extension_header(http_kwargs)
47-
actual_extensions = set(
48-
result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ')
49-
)
50-
expected_extensions = {'test_extension_1', 'test_extension_2'}
47+
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
48+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
49+
actual_extensions = set(actual_extensions_list)
50+
51+
expected_extensions = {
52+
'test_extension_1',
53+
'test_extension_2',
54+
}
55+
assert len(actual_extensions_list) == 2
5156
assert actual_extensions == expected_extensions
5257

58+
@pytest.mark.parametrize(
59+
'existing_header, expected_count',
60+
[
61+
('test_extension_1, test_extension_2', 3),
62+
('test_extension_1,test_extension_2', 3),
63+
('test_extension_1', 3),
64+
],
65+
)
5366
def test_update_extension_header_merge_with_existing_extensions(
54-
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
67+
self,
68+
mock_httpx_client: AsyncMock,
69+
mock_agent_card: MagicMock,
70+
existing_header: str,
71+
expected_count: int,
5572
):
5673
extensions = ['test_extension_2', 'test_extension_3']
5774
client = RestTransport(
5875
httpx_client=mock_httpx_client,
5976
agent_card=mock_agent_card,
6077
client_extensions=extensions,
6178
)
62-
http_kwargs = {
63-
'headers': {
64-
HTTP_EXTENSION_HEADER: 'test_extension_1, test_extension_2'
65-
}
66-
}
79+
http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}}
6780
result_kwargs = client._update_extension_header(http_kwargs)
68-
actual_extensions = set(
69-
result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ')
70-
)
81+
82+
header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER]
83+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
84+
actual_extensions = set(actual_extensions_list)
85+
7186
expected_extensions = {
7287
'test_extension_1',
7388
'test_extension_2',
7489
'test_extension_3',
7590
}
91+
assert len(actual_extensions_list) == expected_count
7692
assert actual_extensions == expected_extensions
7793

7894
def test_update_extension_header_with_other_headers(
@@ -124,8 +140,15 @@ async def test_send_message_with_extensions(
124140

125141
headers = kwargs.get('headers', {})
126142
assert HTTP_EXTENSION_HEADER in headers
127-
actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', '))
128-
expected_extensions = {'test_extension_1', 'test_extension_2'}
143+
header_value = kwargs['headers'][HTTP_EXTENSION_HEADER]
144+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
145+
actual_extensions = set(actual_extensions_list)
146+
147+
expected_extensions = {
148+
'test_extension_1',
149+
'test_extension_2',
150+
}
151+
assert len(actual_extensions_list) == 2
129152
assert actual_extensions == expected_extensions
130153

131154
@pytest.mark.asyncio
@@ -140,8 +163,8 @@ async def test_send_message_streaming_with_extensions(
140163
extensions = ['test_extension']
141164
client = RestTransport(
142165
httpx_client=mock_httpx_client,
143-
client_extensions=extensions,
144166
agent_card=mock_agent_card,
167+
client_extensions=extensions,
145168
)
146169
params = MessageSendParams(
147170
message=create_text_message_object(content='Hello stream')

0 commit comments

Comments
 (0)