|
7 | 7 | from httpx_sse import EventSource, ServerSentEvent |
8 | 8 |
|
9 | 9 | from a2a.client import create_text_message_object |
| 10 | +from a2a.client.errors import A2AClientHTTPError |
10 | 11 | from a2a.client.transports.rest import RestTransport |
11 | 12 | from a2a.extensions.common import HTTP_EXTENSION_HEADER |
12 | 13 | from a2a.types import ( |
13 | 14 | AgentCapabilities, |
14 | 15 | AgentCard, |
15 | | - AgentSkill, |
16 | 16 | MessageSendParams, |
17 | | - Role, |
18 | 17 | ) |
19 | 18 |
|
20 | 19 |
|
@@ -130,6 +129,45 @@ async def test_send_message_streaming_with_new_extensions( |
130 | 129 | }, |
131 | 130 | ) |
132 | 131 |
|
| 132 | + @pytest.mark.asyncio |
| 133 | + @patch('a2a.client.transports.rest.aconnect_sse') |
| 134 | + async def test_send_message_streaming_server_error_propagates( |
| 135 | + self, |
| 136 | + mock_aconnect_sse: AsyncMock, |
| 137 | + mock_httpx_client: AsyncMock, |
| 138 | + mock_agent_card: MagicMock, |
| 139 | + ): |
| 140 | + """Test that send_message_streaming propagates server errors (e.g., 403, 500) directly.""" |
| 141 | + client = RestTransport( |
| 142 | + httpx_client=mock_httpx_client, |
| 143 | + agent_card=mock_agent_card, |
| 144 | + ) |
| 145 | + params = MessageSendParams( |
| 146 | + message=create_text_message_object(content='Error stream') |
| 147 | + ) |
| 148 | + |
| 149 | + mock_event_source = AsyncMock(spec=EventSource) |
| 150 | + mock_response = MagicMock(spec=httpx.Response) |
| 151 | + mock_response.status_code = 403 |
| 152 | + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( |
| 153 | + 'Forbidden', |
| 154 | + request=httpx.Request('POST', 'http://test.url'), |
| 155 | + response=mock_response, |
| 156 | + ) |
| 157 | + mock_event_source.response = mock_response |
| 158 | + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) |
| 159 | + mock_aconnect_sse.return_value.__aenter__.return_value = ( |
| 160 | + mock_event_source |
| 161 | + ) |
| 162 | + |
| 163 | + with pytest.raises(A2AClientHTTPError) as exc_info: |
| 164 | + async for _ in client.send_message_streaming(request=params): |
| 165 | + pass |
| 166 | + |
| 167 | + assert exc_info.value.status_code == 403 |
| 168 | + |
| 169 | + mock_aconnect_sse.assert_called_once() |
| 170 | + |
133 | 171 | @pytest.mark.asyncio |
134 | 172 | async def test_get_card_no_card_provided_with_extensions( |
135 | 173 | self, mock_httpx_client: AsyncMock |
|
0 commit comments