Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/a2a/client/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ClientCallContext(BaseModel):
"""

state: MutableMapping[str, Any] = Field(default_factory=dict)
timeout: float | None = None


class ClientCallInterceptor(ABC):
Expand Down
111 changes: 72 additions & 39 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,8 @@
extensions: list[str] | None = None,
) -> SendMessageResponse:
"""Sends a non-streaming message request to the agent."""
return await self.stub.SendMessage(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.SendMessage, request, context, extensions
)

@_handle_grpc_stream_exception
Expand All @@ -148,14 +147,9 @@
extensions: list[str] | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Sends a streaming message request to the agent and yields responses as they arrive."""
stream = self.stub.SendStreamingMessage(
request,
metadata=self._get_grpc_metadata(extensions),
)
while True:
response = await stream.read()
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
break
async for response in self._call_grpc_stream(
self.stub.SendStreamingMessage, request, context, extensions
):
yield response

@_handle_grpc_stream_exception
Expand All @@ -167,14 +161,9 @@
extensions: list[str] | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Reconnects to get task updates."""
stream = self.stub.SubscribeToTask(
request,
metadata=self._get_grpc_metadata(extensions),
)
while True:
response = await stream.read()
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
break
async for response in self._call_grpc_stream(
self.stub.SubscribeToTask, request, context, extensions
):
yield response

@_handle_grpc_exception
Expand All @@ -186,9 +175,8 @@
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
return await self.stub.GetTask(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.GetTask, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -200,9 +188,8 @@
extensions: list[str] | None = None,
) -> ListTasksResponse:
"""Retrieves tasks for an agent."""
return await self.stub.ListTasks(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.ListTasks, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -214,9 +201,8 @@
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
return await self.stub.CancelTask(
request,
metadata=self._get_grpc_metadata(extensions),
return await self._call_grpc(
self.stub.CancelTask, request, context, extensions
)

@_handle_grpc_exception
Expand All @@ -228,9 +214,11 @@
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
return await self.stub.CreateTaskPushNotificationConfig(
return await self._call_grpc(
self.stub.CreateTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -242,9 +230,11 @@
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
return await self.stub.GetTaskPushNotificationConfig(
return await self._call_grpc(
self.stub.GetTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -256,9 +246,11 @@
extensions: list[str] | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""
return await self.stub.ListTaskPushNotificationConfigs(
return await self._call_grpc(
self.stub.ListTaskPushNotificationConfigs,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
Expand All @@ -270,24 +262,25 @@
extensions: list[str] | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task."""
await self.stub.DeleteTaskPushNotificationConfig(
await self._call_grpc(
self.stub.DeleteTaskPushNotificationConfig,
request,
metadata=self._get_grpc_metadata(extensions),
context,
extensions,
)

@_handle_grpc_exception
async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = await self.stub.GetExtendedAgentCard(
request,
metadata=self._get_grpc_metadata(extensions),
card = await self._call_grpc(

Check notice on line 282 in src/a2a/client/transports/grpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/tenant_decorator.py (173-389)
self.stub.GetExtendedAgentCard, request, context, extensions
)

if signature_verifier:
Expand Down Expand Up @@ -315,3 +308,43 @@
)

return metadata

def _get_grpc_timeout(
self, context: ClientCallContext | None
) -> float | None:
return context.timeout if context else None

async def _call_grpc(
self,
method: Callable[..., Any],
request: Any,
context: ClientCallContext | None,
extensions: list[str] | None,
**kwargs: Any,
) -> Any:
return await method(
request,
metadata=self._get_grpc_metadata(extensions),
timeout=self._get_grpc_timeout(context),
**kwargs,
)

async def _call_grpc_stream(
self,
method: Callable[..., Any],
request: Any,
context: ClientCallContext | None,
extensions: list[str] | None,
**kwargs: Any,
) -> AsyncGenerator[StreamResponse]:
stream = method(
request,
metadata=self._get_grpc_metadata(extensions),
timeout=self._get_grpc_timeout(context),
**kwargs,
)
while True:
response = await stream.read()
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
break
yield response
7 changes: 5 additions & 2 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,14 +454,17 @@
self.agent_card,
context,
)
return final_request_payload, final_http_kwargs

def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None
) -> dict[str, Any]:
http_kwargs: dict[str, Any] = {}
if context and context.timeout is not None:
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
return http_kwargs

def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:

Check notice on line 467 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (395-405)
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""
code = error_dict.get('code')
message = error_dict.get('message', str(error_dict))
Expand Down
7 changes: 5 additions & 2 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,17 @@
final_http_kwargs = http_kwargs or {}
final_request_payload = request_payload
# TODO: Implement interceptors for other transports
return final_request_payload, final_http_kwargs

def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None
) -> dict[str, Any]:
http_kwargs: dict[str, Any] = {}
if context and context.timeout is not None:
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
return http_kwargs

async def _prepare_send_message(

Check notice on line 405 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (457-467)
self,
request: SendMessageRequest,
context: ClientCallContext | None,
Expand Down
34 changes: 34 additions & 0 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,32 @@ async def test_send_message_task_response(
assert response.task.id == sample_task.id


@pytest.mark.asyncio
async def test_send_message_with_timeout_context(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_message_send_params: SendMessageRequest,
sample_task: Task,
) -> None:
"""Test send_message passes context timeout to grpc stub."""
from a2a.client.middleware import ClientCallContext

mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
task=sample_task
)
context = ClientCallContext(timeout=12.5)

await grpc_transport.send_message(
sample_message_send_params,
context=context,
)

mock_grpc_stub.SendMessage.assert_awaited_once()
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert 'timeout' in kwargs
assert kwargs['timeout'] == 12.5


@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
@pytest.mark.asyncio
async def test_grpc_mapped_errors(
Expand Down Expand Up @@ -360,6 +386,7 @@ async def test_get_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.id == sample_task.id

Expand Down Expand Up @@ -389,6 +416,7 @@ async def test_list_tasks(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert result.total_size == 2
assert not result.next_page_token
Expand Down Expand Up @@ -417,6 +445,7 @@ async def test_get_task_with_history(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)


Expand All @@ -443,6 +472,7 @@ async def test_cancel_task(
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'),
],
timeout=None,
)
assert response.status.state == TaskState.TASK_STATE_CANCELED

Expand Down Expand Up @@ -476,6 +506,7 @@ async def test_create_task_push_notification_config_with_valid_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.task_id == sample_task_push_notification_config.task_id

Expand Down Expand Up @@ -539,6 +570,7 @@ async def test_get_task_push_notification_config_with_valid_task(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert response.task_id == sample_task_push_notification_config.task_id

Expand Down Expand Up @@ -593,6 +625,7 @@ async def test_list_task_push_notification_configs(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)
assert len(response.configs) == 1
assert response.configs[0].task_id == 'task-1'
Expand Down Expand Up @@ -626,6 +659,7 @@ async def test_delete_task_push_notification_config(
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
),
],
timeout=None,
)


Expand Down
26 changes: 26 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,32 @@ async def test_send_message_json_decode_error(
with pytest.raises(A2AClientError):
await transport.send_message(request)

@pytest.mark.asyncio
async def test_send_message_with_timeout_context(
self, transport, mock_httpx_client
):
"""Test that send_message passes context timeout to build_request."""
from a2a.client.middleware import ClientCallContext

mock_response = MagicMock()
mock_response.json.return_value = {
'jsonrpc': '2.0',
'id': '1',
'result': {},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.send.return_value = mock_response

request = create_send_message_request()
context = ClientCallContext(timeout=15.0)

await transport.send_message(request, context=context)

mock_httpx_client.build_request.assert_called_once()
_, kwargs = mock_httpx_client.build_request.call_args
assert 'timeout' in kwargs
assert kwargs['timeout'] == httpx.Timeout(15.0)


class TestGetTask:
"""Tests for the get_task method."""
Expand Down
Loading
Loading