diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index f4a8d03d..b8697d86 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,4 +1,5 @@ from collections.abc import AsyncIterator +from typing import Any from a2a.client.client import ( Client, @@ -47,6 +48,7 @@ async def send_message( request: Message, *, context: ClientCallContext | None = None, + request_metadata: dict[str, Any] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the agent. @@ -57,6 +59,7 @@ async def send_message( Args: request: The message to send to the agent. context: The client call context. + request_metadata: Extensions Metadata attached to the request. Yields: An async iterator of `ClientEvent` or a final `Message` response. @@ -70,7 +73,9 @@ async def send_message( else None ), ) - params = MessageSendParams(message=request, configuration=config) + params = MessageSendParams( + message=request, configuration=config, metadata=request_metadata + ) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport.send_message( diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 7cc10423..0e1c4323 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -110,6 +110,7 @@ async def send_message( request: Message, *, context: ClientCallContext | None = None, + request_metadata: dict[str, Any] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index d93a2203..f5ab2543 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -73,9 +73,14 @@ async def create_stream(*args, **kwargs): mock_transport.send_message_streaming.return_value = create_stream() - events = [event async for event in base_client.send_message(sample_message)] + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] mock_transport.send_message_streaming.assert_called_once() + assert ( + mock_transport.send_message_streaming.call_args[0][0].metadata == meta + ) assert not mock_transport.send_message.called assert len(events) == 1 assert events[0][0].id == 'task-123' @@ -92,9 +97,12 @@ async def test_send_message_non_streaming( status=TaskStatus(state=TaskState.completed), ) - events = [event async for event in base_client.send_message(sample_message)] + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] mock_transport.send_message.assert_called_once() + assert mock_transport.send_message.call_args[0][0].metadata == meta assert not mock_transport.send_message_streaming.called assert len(events) == 1 assert events[0][0].id == 'task-456'