Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import AsyncIterator
from typing import Any

from a2a.client.client import (
Client,
Expand Down Expand Up @@ -47,6 +48,7 @@
request: Message,
*,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
) -> AsyncIterator[ClientEvent | Message]:
"""Sends a message to the agent.
Expand All @@ -57,6 +59,7 @@
Args:
request: The message to send to the agent.
context: The client call context.
request_metadata: Extensions Metadata attached to the request.
Yields:
An async iterator of `ClientEvent` or a final `Message` response.
Expand All @@ -70,7 +73,7 @@
else None
),
)
params = MessageSendParams(message=request, configuration=config)
params = MessageSendParams(message=request, configuration=config, metadata=request_metadata)

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
Expand Down
1 change: 1 addition & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def send_message(
request: Message,
*,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
) -> AsyncIterator[ClientEvent | Message]:
"""Sends a message to the server.
Expand Down
10 changes: 8 additions & 2 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from unittest.mock import AsyncMock, MagicMock

import pytest
Expand Down Expand Up @@ -73,9 +73,12 @@

mock_transport.send_message_streaming.return_value = create_stream()

events = [event async for event in base_client.send_message(sample_message)]
meta = {"test":1}
stream = base_client.send_message(sample_message, request_metadata=meta)
events = [event async for event in stream]

mock_transport.send_message_streaming.assert_called_once()
assert mock_transport.send_message_streaming.call_args[0][0].metadata == meta
assert not mock_transport.send_message.called
assert len(events) == 1
assert events[0][0].id == 'task-123'
Expand All @@ -92,9 +95,12 @@
status=TaskStatus(state=TaskState.completed),
)

events = [event async for event in base_client.send_message(sample_message)]
meta = {"test":1}
stream = base_client.send_message(sample_message, request_metadata=meta)
events = [event async for event in stream]

mock_transport.send_message.assert_called_once()
assert mock_transport.send_message.call_args[0][0].metadata == meta
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-456'
Expand Down
Loading