Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/a2a/client/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ClientCallContext(BaseModel):
"""

state: MutableMapping[str, Any] = Field(default_factory=dict)
timeout: float | None = None


class ClientCallInterceptor(ABC):
Expand Down
111 changes: 72 additions & 39 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,8 @@
extensions: list[str] | None = None,
) -> SendMessageResponse:
"""Sends a non-streaming message request to the agent."""
return await self.stub.SendMessage(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.SendMessage, request, context, extensions
)

@_handle_grpc_stream_exception
Expand All @@ -148,14 +147,9 @@
extensions: list[str] | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Sends a streaming message request to the agent and yields responses as they arrive."""
stream = self.stub.SendStreamingMessage(
request,
metadata=self._get_grpc_metadata(extensions),
)
while True:
response = await stream.read()
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
break
async for response in self._call_grpc_stream(
self.stub.SendStreamingMessage, request, context, extensions
):
yield response

@_handle_grpc_stream_exception
Expand All @@ -167,14 +161,9 @@
extensions: list[str] | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Reconnects to get task updates."""
stream = self.stub.SubscribeToTask(
request,
metadata=self._get_grpc_metadata(extensions),
)
while True:
response = await stream.read()
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
break
async for response in self._call_grpc_stream(
self.stub.SubscribeToTask, request, context, extensions
):
yield response

@_handle_grpc_exception
Expand All @@ -186,9 +175,8 @@
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
return await self.stub.GetTask(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.GetTask, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -200,9 +188,8 @@
extensions: list[str] | None = None,
) -> ListTasksResponse:
"""Retrieves tasks for an agent."""
return await self.stub.ListTasks(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.ListTasks, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -214,9 +201,8 @@
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
return await self.stub.CancelTask(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.CancelTask, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -228,9 +214,11 @@
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
return await self.stub.CreateTaskPushNotificationConfig(
return await self._call_grpc(
self.stub.CreateTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -242,9 +230,11 @@
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
return await self.stub.GetTaskPushNotificationConfig(
return await self._call_grpc(
self.stub.GetTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -256,9 +246,11 @@
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""
return await self.stub.ListTaskPushNotificationConfigs(
return await self._call_grpc(
self.stub.ListTaskPushNotificationConfigs,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -270,24 +262,25 @@
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task."""
await self.stub.DeleteTaskPushNotificationConfig(
await self._call_grpc(
self.stub.DeleteTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = await self.stub.GetExtendedAgentCard(
request,
metadata=self._get_grpc_metadata(extensions),
card = await self._call_grpc(

Check notice on line 282 in src/a2a/client/transports/grpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/tenant_decorator.py (173-389)
self.stub.GetExtendedAgentCard, request, context, extensions
)

if signature_verifier:
Expand Down Expand Up @@ -315,3 +308,43 @@
)

return metadata

def _get_grpc_timeout(
self, context: ClientCallContext | None
) -> float | None:
return context.timeout if context else None

async def _call_grpc(
self,
method: Callable[..., Any],
request: Any,
context: ClientCallContext | None,
extensions: list[str] | None,
**kwargs: Any,
) -> Any:
return await method(
request,
metadata=self._get_grpc_metadata(extensions),
timeout=self._get_grpc_timeout(context),
**kwargs,
)

async def _call_grpc_stream(
self,
method: Callable[..., Any],
request: Any,
context: ClientCallContext | None,
extensions: list[str] | None,
**kwargs: Any,
) -> AsyncGenerator[StreamResponse]:
stream = method(
request,
metadata=self._get_grpc_metadata(extensions),
timeout=self._get_grpc_timeout(context),
**kwargs,
)
while True:
response = await stream.read()
if response == grpc.aio.EOF:

Check failure on line 348 in src/a2a/client/transports/grpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

"EOF" is not a known attribute of module "..aio" (reportAttributeAccessIssue)
break
yield response
5 changes: 4 additions & 1 deletion src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,15 @@
self.agent_card,
context,
)
return final_request_payload, final_http_kwargs

def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None
http_kwargs: dict[str, Any] = {}
if context and context.timeout is not None:
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
return http_kwargs if http_kwargs else None

Check notice on line 465 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (395-405)

def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""
Expand Down
5 changes: 4 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,18 +377,21 @@
def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None
http_kwargs: dict[str, Any] = {}
if context and context.timeout is not None:
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
return http_kwargs

async def _prepare_send_message(
self,
request: SendMessageRequest,
context: ClientCallContext | None,
extensions: list[str] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
payload = MessageToDict(request)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,

Check notice on line 394 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (438-446)
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
Expand Down
34 changes: 34 additions & 0 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,32 @@ async def test_send_message_task_response(
assert response.task.id == sample_task.id


@pytest.mark.asyncio
async def test_send_message_with_timeout_context(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_message_send_params: SendMessageRequest,
sample_task: Task,
) -> None:
"""Test send_message passes context timeout to grpc stub."""
from a2a.client.middleware import ClientCallContext

mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
task=sample_task
)
context = ClientCallContext(timeout=12.5)

await grpc_transport.send_message(
sample_message_send_params,
context=context,
)

mock_grpc_stub.SendMessage.assert_awaited_once()
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert 'timeout' in kwargs
assert kwargs['timeout'] == 12.5


@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
@pytest.mark.asyncio
async def test_grpc_mapped_errors(
Expand Down Expand Up @@ -360,6 +386,7 @@ async def test_get_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.id == sample_task.id

Expand Down Expand Up @@ -389,6 +416,7 @@ async def test_list_tasks(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert result.total_size == 2
assert not result.next_page_token
Expand Down Expand Up @@ -417,6 +445,7 @@ async def test_get_task_with_history(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)


Expand All @@ -443,6 +472,7 @@ async def test_cancel_task(
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'),
],
timeout=None,
)
assert response.status.state == TaskState.TASK_STATE_CANCELED

Expand Down Expand Up @@ -476,6 +506,7 @@ async def test_create_task_push_notification_config_with_valid_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.task_id == sample_task_push_notification_config.task_id

Expand Down Expand Up @@ -539,6 +570,7 @@ async def test_get_task_push_notification_config_with_valid_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.task_id == sample_task_push_notification_config.task_id

Expand Down Expand Up @@ -593,6 +625,7 @@ async def test_list_task_push_notification_configs(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert len(response.configs) == 1
assert response.configs[0].task_id == 'task-1'
Expand Down Expand Up @@ -626,6 +659,7 @@ async def test_delete_task_push_notification_config(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)


Expand Down
26 changes: 26 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,32 @@ async def test_send_message_json_decode_error(
with pytest.raises(A2AClientError):
await transport.send_message(request)

@pytest.mark.asyncio
async def test_send_message_with_timeout_context(
self, transport, mock_httpx_client
):
"""Test that send_message passes context timeout to build_request."""
from a2a.client.middleware import ClientCallContext

mock_response = MagicMock()
mock_response.json.return_value = {
'jsonrpc': '2.0',
'id': '1',
'result': {},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.send.return_value = mock_response

request = create_send_message_request()
context = ClientCallContext(timeout=15.0)

await transport.send_message(request, context=context)

mock_httpx_client.build_request.assert_called_once()
_, kwargs = mock_httpx_client.build_request.call_args
assert 'timeout' in kwargs
assert kwargs['timeout'] == httpx.Timeout(15.0)


class TestGetTask:
"""Tests for the get_task method."""
Expand Down
33 changes: 33 additions & 0 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,39 @@ async def test_rest_mapped_errors(
with pytest.raises(error_cls):
await client.send_message(request=params)

@pytest.mark.asyncio
async def test_send_message_with_timeout_context(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that send_message passes context timeout to build_request."""
from a2a.client.middleware import ClientCallContext

client = RestTransport(
httpx_client=mock_httpx_client,
agent_card=mock_agent_card,
url='http://agent.example.com/api',
)
params = SendMessageRequest(
message=create_text_message_object(content='Hello')
)
context = ClientCallContext(timeout=10.0)

mock_build_request = MagicMock(
return_value=AsyncMock(spec=httpx.Request)
)
mock_httpx_client.build_request = mock_build_request

mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
mock_httpx_client.send.return_value = mock_response

await client.send_message(request=params, context=context)

mock_build_request.assert_called_once()
_, kwargs = mock_build_request.call_args
assert 'timeout' in kwargs
assert kwargs['timeout'] == httpx.Timeout(10.0)


class TestRestTransportExtensions:
@pytest.mark.asyncio
Expand Down
Loading