Skip to content

Commit f81206b

Browse files
committed
add metadata to send message request
1 parent b2e3a29 commit f81206b

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

src/a2a/client/base_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncIterator
2+
from typing import Any
23

34
from a2a.client.client import (
45
Client,
@@ -47,6 +48,7 @@ async def send_message(
4748
request: Message,
4849
*,
4950
context: ClientCallContext | None = None,
51+
request_metadata: dict[str, Any] | None = None,
5052
) -> AsyncIterator[ClientEvent | Message]:
5153
"""Sends a message to the agent.
5254
@@ -57,6 +59,7 @@ async def send_message(
5759
Args:
5860
request: The message to send to the agent.
5961
context: The client call context.
62+
request_metadata: Extensions Metadata attached to the request.
6063
6164
Yields:
6265
An async iterator of `ClientEvent` or a final `Message` response.
@@ -70,7 +73,7 @@ async def send_message(
7073
else None
7174
),
7275
)
73-
params = MessageSendParams(message=request, configuration=config)
76+
params = MessageSendParams(message=request, configuration=config, metadata=request_metadata)
7477

7578
if not self._config.streaming or not self._card.capabilities.streaming:
7679
response = await self._transport.send_message(

src/a2a/client/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def send_message(
110110
request: Message,
111111
*,
112112
context: ClientCallContext | None = None,
113+
request_metadata: dict[str, Any] | None = None,
113114
) -> AsyncIterator[ClientEvent | Message]:
114115
"""Sends a message to the server.
115116

tests/client/test_base_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,12 @@ async def create_stream(*args, **kwargs):
7171

7272
mock_transport.send_message_streaming.return_value = create_stream()
7373

74-
events = [event async for event in base_client.send_message(sample_message)]
74+
meta = {"test":1}
75+
stream = base_client.send_message(sample_message, request_metadata=meta)
76+
events = [event async for event in stream]
7577

7678
mock_transport.send_message_streaming.assert_called_once()
79+
assert mock_transport.send_message_streaming.call_args[0][0].metadata == meta
7780
assert not mock_transport.send_message.called
7881
assert len(events) == 1
7982
assert events[0][0].id == 'task-123'
@@ -90,9 +93,12 @@ async def test_send_message_non_streaming(
9093
status=TaskStatus(state=TaskState.completed),
9194
)
9295

93-
events = [event async for event in base_client.send_message(sample_message)]
96+
meta = {"test":1}
97+
stream = base_client.send_message(sample_message, request_metadata=meta)
98+
events = [event async for event in stream]
9499

95100
mock_transport.send_message.assert_called_once()
101+
assert mock_transport.send_message.call_args[0][0].metadata == meta
96102
assert not mock_transport.send_message_streaming.called
97103
assert len(events) == 1
98104
assert events[0][0].id == 'task-456'

0 commit comments

Comments
 (0)