Skip to content

Commit 900133d

Browse files
committed
refractor: reduce redundancy by extracting duplicated code into a shared helper function.
1 parent bdc81a1 commit 900133d

File tree

2 files changed

+68
-80
lines changed

2 files changed

+68
-80
lines changed

tests/client/transports/test_jsonrpc_client.py

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ async def async_iterable_from_list(
114114
yield item
115115

116116

117+
def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
118+
headers = mock_kwargs.get('headers', {})
119+
assert HTTP_EXTENSION_HEADER in headers
120+
header_value = headers[HTTP_EXTENSION_HEADER]
121+
actual_extensions = {e.strip() for e in header_value.split(',')}
122+
assert actual_extensions == expected_extensions
123+
124+
117125
class TestA2ACardResolver:
118126
BASE_URL = 'http://example.com'
119127
AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH
@@ -823,18 +831,13 @@ async def test_send_message_with_default_extensions(
823831
mock_httpx_client.post.assert_called_once()
824832
_, mock_kwargs = mock_httpx_client.post.call_args
825833

826-
headers = mock_kwargs.get('headers', {})
827-
assert HTTP_EXTENSION_HEADER in headers
828-
header_value = headers[HTTP_EXTENSION_HEADER]
829-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
830-
actual_extensions = set(actual_extensions_list)
831-
832-
expected_extensions = {
833-
'https://example.com/test-ext/v1',
834-
'https://example.com/test-ext/v2',
835-
}
836-
assert len(actual_extensions_list) == 2
837-
assert actual_extensions == expected_extensions
834+
_assert_extensions_header(
835+
mock_kwargs,
836+
{
837+
'https://example.com/test-ext/v1',
838+
'https://example.com/test-ext/v2',
839+
},
840+
)
838841

839842
@pytest.mark.asyncio
840843
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
@@ -870,10 +873,11 @@ async def test_send_message_streaming_with_new_extensions(
870873
mock_aconnect_sse.assert_called_once()
871874
_, kwargs = mock_aconnect_sse.call_args
872875

873-
headers = kwargs.get('headers', {})
874-
assert HTTP_EXTENSION_HEADER in headers
875-
assert (
876-
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
876+
_assert_extensions_header(
877+
kwargs,
878+
{
879+
'https://example.com/test-ext/v2',
880+
},
877881
)
878882

879883
@pytest.mark.asyncio
@@ -901,18 +905,13 @@ async def test_get_card_no_card_provided_with_extensions(
901905
mock_httpx_client.get.assert_called_once()
902906
_, mock_kwargs = mock_httpx_client.get.call_args
903907

904-
headers = mock_kwargs.get('headers', {})
905-
assert HTTP_EXTENSION_HEADER in headers
906-
header_value = headers[HTTP_EXTENSION_HEADER]
907-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
908-
actual_extensions = set(actual_extensions_list)
909-
910-
expected_extensions = {
911-
'https://example.com/test-ext/v1',
912-
'https://example.com/test-ext/v2',
913-
}
914-
assert len(actual_extensions_list) == 2
915-
assert actual_extensions == expected_extensions
908+
_assert_extensions_header(
909+
mock_kwargs,
910+
{
911+
'https://example.com/test-ext/v1',
912+
'https://example.com/test-ext/v2',
913+
},
914+
)
916915

917916
@pytest.mark.asyncio
918917
async def test_get_card_with_extended_card_support_with_extensions(
@@ -947,15 +946,10 @@ async def test_get_card_with_extended_card_support_with_extensions(
947946
mock_send_request.assert_called_once()
948947
_, mock_kwargs = mock_send_request.call_args[0]
949948

950-
headers = mock_kwargs.get('headers', {})
951-
assert HTTP_EXTENSION_HEADER in headers
952-
header_value = headers[HTTP_EXTENSION_HEADER]
953-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
954-
actual_extensions = set(actual_extensions_list)
955-
956-
expected_extensions = {
957-
'https://example.com/test-ext/v1',
958-
'https://example.com/test-ext/v2',
959-
}
960-
assert len(actual_extensions_list) == 2
961-
assert actual_extensions == expected_extensions
949+
_assert_extensions_header(
950+
mock_kwargs,
951+
{
952+
'https://example.com/test-ext/v1',
953+
'https://example.com/test-ext/v2',
954+
},
955+
)

tests/client/transports/test_rest_client.py

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ async def async_iterable_from_list(
3838
yield item
3939

4040

41+
def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
42+
headers = mock_kwargs.get('headers', {})
43+
assert HTTP_EXTENSION_HEADER in headers
44+
header_value = headers[HTTP_EXTENSION_HEADER]
45+
actual_extensions = {e.strip() for e in header_value.split(',')}
46+
assert actual_extensions == expected_extensions
47+
48+
4149
class TestRestTransportExtensions:
4250
@pytest.mark.asyncio
4351
async def test_send_message_with_default_extensions(
@@ -73,18 +81,13 @@ async def test_send_message_with_default_extensions(
7381
mock_build_request.assert_called_once()
7482
_, kwargs = mock_build_request.call_args
7583

76-
headers = kwargs.get('headers', {})
77-
assert HTTP_EXTENSION_HEADER in headers
78-
header_value = kwargs['headers'][HTTP_EXTENSION_HEADER]
79-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
80-
actual_extensions = set(actual_extensions_list)
81-
82-
expected_extensions = {
83-
'https://example.com/test-ext/v1',
84-
'https://example.com/test-ext/v2',
85-
}
86-
assert len(actual_extensions_list) == 2
87-
assert actual_extensions == expected_extensions
84+
_assert_extensions_header(
85+
kwargs,
86+
{
87+
'https://example.com/test-ext/v1',
88+
'https://example.com/test-ext/v2',
89+
},
90+
)
8891

8992
@pytest.mark.asyncio
9093
@patch('a2a.client.transports.rest.aconnect_sse')
@@ -120,10 +123,11 @@ async def test_send_message_streaming_with_new_extensions(
120123
mock_aconnect_sse.assert_called_once()
121124
_, kwargs = mock_aconnect_sse.call_args
122125

123-
headers = kwargs.get('headers', {})
124-
assert HTTP_EXTENSION_HEADER in headers
125-
assert (
126-
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
126+
_assert_extensions_header(
127+
kwargs,
128+
{
129+
'https://example.com/test-ext/v2',
130+
},
127131
)
128132

129133
@pytest.mark.asyncio
@@ -161,18 +165,13 @@ async def test_get_card_no_card_provided_with_extensions(
161165
mock_httpx_client.get.assert_called_once()
162166
_, mock_kwargs = mock_httpx_client.get.call_args
163167

164-
headers = mock_kwargs.get('headers', {})
165-
assert HTTP_EXTENSION_HEADER in headers
166-
header_value = headers[HTTP_EXTENSION_HEADER]
167-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
168-
actual_extensions = set(actual_extensions_list)
169-
170-
expected_extensions = {
171-
'https://example.com/test-ext/v1',
172-
'https://example.com/test-ext/v2',
173-
}
174-
assert len(actual_extensions_list) == 2
175-
assert actual_extensions == expected_extensions
168+
_assert_extensions_header(
169+
mock_kwargs,
170+
{
171+
'https://example.com/test-ext/v1',
172+
'https://example.com/test-ext/v2',
173+
},
174+
)
176175

177176
@pytest.mark.asyncio
178177
async def test_get_card_with_extended_card_support_with_extensions(
@@ -216,15 +215,10 @@ async def test_get_card_with_extended_card_support_with_extensions(
216215
mock_send_get_request.assert_called_once()
217216
_, _, mock_kwargs = mock_send_get_request.call_args[0]
218217

219-
headers = mock_kwargs.get('headers', {})
220-
assert HTTP_EXTENSION_HEADER in headers
221-
header_value = headers[HTTP_EXTENSION_HEADER]
222-
actual_extensions_list = [e.strip() for e in header_value.split(',')]
223-
actual_extensions = set(actual_extensions_list)
224-
225-
expected_extensions = {
226-
'https://example.com/test-ext/v1',
227-
'https://example.com/test-ext/v2',
228-
}
229-
assert len(actual_extensions_list) == 2
230-
assert actual_extensions == expected_extensions
218+
_assert_extensions_header(
219+
mock_kwargs,
220+
{
221+
'https://example.com/test-ext/v1',
222+
'https://example.com/test-ext/v2',
223+
},
224+
)

0 commit comments

Comments
 (0)