Skip to content

Commit bdc81a1

Browse files
committed
fix(jsonrpc, rest): fix get_card methods in json-rpc and rest transports. Headers are now updated with extensions before the get_agent_card call.
1 parent dbc73e8 commit bdc81a1

File tree

4 files changed

+204
-15
lines changed

4 files changed

+204
-15
lines changed

src/a2a/client/transports/jsonrpc.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,14 @@ async def get_card(
378378
extensions: list[str] | None = None,
379379
) -> AgentCard:
380380
"""Retrieves the agent's card."""
381+
modified_kwargs = update_extension_header(
382+
self._get_http_args(context),
383+
extensions if extensions is not None else self.extensions,
384+
)
381385
card = self.agent_card
382386
if not card:
383387
resolver = A2ACardResolver(self.httpx_client, self.url)
384-
card = await resolver.get_agent_card(
385-
http_kwargs=self._get_http_args(context)
386-
)
388+
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
387389
self._needs_extended_card = (
388390
card.supports_authenticated_extended_card
389391
)
@@ -393,10 +395,6 @@ async def get_card(
393395
return card
394396

395397
request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))
396-
modified_kwargs = update_extension_header(
397-
self._get_http_args(context),
398-
extensions if extensions is not None else self.extensions,
399-
)
400398
payload, modified_kwargs = await self._apply_interceptors(
401399
request.method,
402400
request.model_dump(mode='json', exclude_none=True),

src/a2a/client/transports/rest.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,14 @@ async def get_card(
370370
extensions: list[str] | None = None,
371371
) -> AgentCard:
372372
"""Retrieves the agent's card."""
373+
modified_kwargs = update_extension_header(
374+
self._get_http_args(context),
375+
extensions if extensions is not None else self.extensions,
376+
)
373377
card = self.agent_card
374378
if not card:
375379
resolver = A2ACardResolver(self.httpx_client, self.url)
376-
card = await resolver.get_agent_card(
377-
http_kwargs=self._get_http_args(context)
378-
)
380+
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
379381
self._needs_extended_card = (
380382
card.supports_authenticated_extended_card
381383
)
@@ -384,10 +386,6 @@ async def get_card(
384386
if not self._needs_extended_card:
385387
return card
386388

387-
modified_kwargs = update_extension_header(
388-
self._get_http_args(context),
389-
extensions if extensions is not None else self.extensions,
390-
)
391389
_, modified_kwargs = await self._apply_interceptors(
392390
{},
393391
modified_kwargs,

tests/client/transports/test_jsonrpc_client.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,3 +875,87 @@ async def test_send_message_streaming_with_new_extensions(
875875
assert (
876876
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
877877
)
878+
879+
@pytest.mark.asyncio
880+
async def test_get_card_no_card_provided_with_extensions(
881+
self, mock_httpx_client: AsyncMock
882+
):
883+
"""Test get_card with extensions set in Client when no card is initially provided.
884+
Tests that the extensions are added to the HTTP GET request."""
885+
extensions = [
886+
'https://example.com/test-ext/v1',
887+
'https://example.com/test-ext/v2',
888+
]
889+
client = JsonRpcTransport(
890+
httpx_client=mock_httpx_client,
891+
url=TestJsonRpcTransport.AGENT_URL,
892+
extensions=extensions,
893+
)
894+
mock_response = AsyncMock(spec=httpx.Response)
895+
mock_response.status_code = 200
896+
mock_response.json.return_value = AGENT_CARD.model_dump(mode='json')
897+
mock_httpx_client.get.return_value = mock_response
898+
899+
await client.get_card()
900+
901+
mock_httpx_client.get.assert_called_once()
902+
_, mock_kwargs = mock_httpx_client.get.call_args
903+
904+
headers = mock_kwargs.get('headers', {})
905+
assert HTTP_EXTENSION_HEADER in headers
906+
header_value = headers[HTTP_EXTENSION_HEADER]
907+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
908+
actual_extensions = set(actual_extensions_list)
909+
910+
expected_extensions = {
911+
'https://example.com/test-ext/v1',
912+
'https://example.com/test-ext/v2',
913+
}
914+
assert len(actual_extensions_list) == 2
915+
assert actual_extensions == expected_extensions
916+
917+
@pytest.mark.asyncio
918+
async def test_get_card_with_extended_card_support_with_extensions(
919+
self, mock_httpx_client: AsyncMock
920+
):
921+
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
922+
Tests that the extensions are added to the RPC request."""
923+
extensions = [
924+
'https://example.com/test-ext/v1',
925+
'https://example.com/test-ext/v2',
926+
]
927+
agent_card = AGENT_CARD.model_copy(
928+
update={'supports_authenticated_extended_card': True}
929+
)
930+
client = JsonRpcTransport(
931+
httpx_client=mock_httpx_client,
932+
agent_card=agent_card,
933+
extensions=extensions,
934+
)
935+
936+
rpc_response = {
937+
'id': '123',
938+
'jsonrpc': '2.0',
939+
'result': AGENT_CARD_EXTENDED.model_dump(mode='json'),
940+
}
941+
with patch.object(
942+
client, '_send_request', new_callable=AsyncMock
943+
) as mock_send_request:
944+
mock_send_request.return_value = rpc_response
945+
await client.get_card(extensions=extensions)
946+
947+
mock_send_request.assert_called_once()
948+
_, mock_kwargs = mock_send_request.call_args[0]
949+
950+
headers = mock_kwargs.get('headers', {})
951+
assert HTTP_EXTENSION_HEADER in headers
952+
header_value = headers[HTTP_EXTENSION_HEADER]
953+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
954+
actual_extensions = set(actual_extensions_list)
955+
956+
expected_extensions = {
957+
'https://example.com/test-ext/v1',
958+
'https://example.com/test-ext/v2',
959+
}
960+
assert len(actual_extensions_list) == 2
961+
assert actual_extensions == expected_extensions

tests/client/transports/test_rest_client.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
from a2a.client import create_text_message_object
1010
from a2a.client.transports.rest import RestTransport
1111
from a2a.extensions.common import HTTP_EXTENSION_HEADER
12-
from a2a.types import AgentCard, MessageSendParams, Role
12+
from a2a.types import (
13+
AgentCapabilities,
14+
AgentCard,
15+
AgentSkill,
16+
MessageSendParams,
17+
Role,
18+
)
1319

1420

1521
@pytest.fixture
@@ -119,3 +125,106 @@ async def test_send_message_streaming_with_new_extensions(
119125
assert (
120126
headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2'
121127
)
128+
129+
@pytest.mark.asyncio
130+
async def test_get_card_no_card_provided_with_extensions(
131+
self, mock_httpx_client: AsyncMock
132+
):
133+
"""Test get_card with extensions set in Client when no card is initially provided.
134+
Tests that the extensions are added to the HTTP GET request."""
135+
extensions = [
136+
'https://example.com/test-ext/v1',
137+
'https://example.com/test-ext/v2',
138+
]
139+
client = RestTransport(
140+
httpx_client=mock_httpx_client,
141+
url='http://agent.example.com/api',
142+
extensions=extensions,
143+
)
144+
145+
mock_response = AsyncMock(spec=httpx.Response)
146+
mock_response.status_code = 200
147+
mock_response.json.return_value = {
148+
'name': 'Test Agent',
149+
'description': 'Test Agent Description',
150+
'url': 'http://agent.example.com/api',
151+
'version': '1.0.0',
152+
'default_input_modes': ['text'],
153+
'default_output_modes': ['text'],
154+
'capabilities': AgentCapabilities().model_dump(),
155+
'skills': [],
156+
}
157+
mock_httpx_client.get.return_value = mock_response
158+
159+
await client.get_card()
160+
161+
mock_httpx_client.get.assert_called_once()
162+
_, mock_kwargs = mock_httpx_client.get.call_args
163+
164+
headers = mock_kwargs.get('headers', {})
165+
assert HTTP_EXTENSION_HEADER in headers
166+
header_value = headers[HTTP_EXTENSION_HEADER]
167+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
168+
actual_extensions = set(actual_extensions_list)
169+
170+
expected_extensions = {
171+
'https://example.com/test-ext/v1',
172+
'https://example.com/test-ext/v2',
173+
}
174+
assert len(actual_extensions_list) == 2
175+
assert actual_extensions == expected_extensions
176+
177+
@pytest.mark.asyncio
178+
async def test_get_card_with_extended_card_support_with_extensions(
179+
self, mock_httpx_client: AsyncMock
180+
):
181+
"""Test get_card with extensions passed to get_card call when extended card support is enabled.
182+
Tests that the extensions are added to the GET request."""
183+
extensions = [
184+
'https://example.com/test-ext/v1',
185+
'https://example.com/test-ext/v2',
186+
]
187+
agent_card = AgentCard(
188+
name='Test Agent',
189+
description='Test Agent Description',
190+
url='http://agent.example.com/api',
191+
version='1.0.0',
192+
default_input_modes=['text'],
193+
default_output_modes=['text'],
194+
capabilities=AgentCapabilities(),
195+
skills=[],
196+
supports_authenticated_extended_card=True,
197+
)
198+
client = RestTransport(
199+
httpx_client=mock_httpx_client,
200+
agent_card=agent_card,
201+
)
202+
203+
mock_response = AsyncMock(spec=httpx.Response)
204+
mock_response.status_code = 200
205+
mock_response.json.return_value = agent_card.model_dump(mode='json')
206+
mock_httpx_client.send.return_value = mock_response
207+
208+
with patch.object(
209+
client, '_send_get_request', new_callable=AsyncMock
210+
) as mock_send_get_request:
211+
mock_send_get_request.return_value = agent_card.model_dump(
212+
mode='json'
213+
)
214+
await client.get_card(extensions=extensions)
215+
216+
mock_send_get_request.assert_called_once()
217+
_, _, mock_kwargs = mock_send_get_request.call_args[0]
218+
219+
headers = mock_kwargs.get('headers', {})
220+
assert HTTP_EXTENSION_HEADER in headers
221+
header_value = headers[HTTP_EXTENSION_HEADER]
222+
actual_extensions_list = [e.strip() for e in header_value.split(',')]
223+
actual_extensions = set(actual_extensions_list)
224+
225+
expected_extensions = {
226+
'https://example.com/test-ext/v1',
227+
'https://example.com/test-ext/v2',
228+
}
229+
assert len(actual_extensions_list) == 2
230+
assert actual_extensions == expected_extensions

0 commit comments

Comments
 (0)