Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
AAgent
ACard
AClient
ACMRTUXB
Expand Down
67 changes: 67 additions & 0 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import logging

from collections.abc import Callable
from typing import Any

import httpx

from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig, Consumer
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
Expand Down Expand Up @@ -101,6 +103,71 @@ def _register_defaults(
GrpcTransport.create,
)

@classmethod
async def connect( # noqa: PLR0913
cls,
agent: str | AgentCard,
client_config: ClientConfig | None = None,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
relative_card_path: str | None = None,
resolver_http_kwargs: dict[str, Any] | None = None,
extra_transports: dict[str, TransportProducer] | None = None,
) -> Client:
"""Convenience method for constructing a client.

Constructs a client that connects to the specified agent. Note that
creating multiple clients via this method is less efficient than
constructing an instance of ClientFactory and reusing that.

.. code-block:: python

# This will search for an AgentCard at /.well-known/agent-card.json
my_agent_url = 'https://travel.agents.example.com'
client = await ClientFactory.connect(my_agent_url)


Args:
agent: The base URL of the agent, or the AgentCard to connect to.
client_config: The ClientConfig to use when connecting to the agent.
consumers: A list of `Consumer` methods to pass responses to.
interceptors: A list of interceptors to use for each request. These
are used for things like attaching credentials or http headers
to all outbound requests.
relative_card_path: If the agent field is a URL, this value is used as
the relative path when resolving the agent card. See
A2AAgentCardResolver.get_agent_card for more details.
resolver_http_kwargs: Dictionary of arguments to provide to the httpx
client when resolving the agent card. This value is provided to
A2AAgentCardResolver.get_agent_card as the http_kwargs parameter.
extra_transports: Additional transport protocols to enable when
constructing the client.

Returns:
A `Client` object.
"""
client_config = client_config or ClientConfig()
if isinstance(agent, str):
if not client_config.httpx_client:
async with httpx.AsyncClient() as client:
resolver = A2ACardResolver(client, agent)
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
)
else:
resolver = A2ACardResolver(client_config.httpx_client, agent)
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
)
else:
card = agent
factory = cls(client_config)
for label, generator in (extra_transports or {}).items():
factory.register(label, generator)
return factory.create(card, consumers, interceptors)

def register(self, label: str, generator: TransportProducer) -> None:
"""Register a new transport producer for a given transport label."""
self._registry[label] = generator
Expand Down
157 changes: 157 additions & 0 deletions tests/client/test_client_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the ClientFactory."""

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

Expand Down Expand Up @@ -103,3 +105,158 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
factory = ClientFactory(config)
with pytest.raises(ValueError, match='no compatible transports found'):
factory.create(base_agent_card)


@pytest.mark.asyncio
async def test_client_factory_connect_with_agent_card(
base_agent_card: AgentCard,
):
"""Verify that connect works correctly when provided with an AgentCard."""
client = await ClientFactory.connect(base_agent_card)
assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_url(base_agent_card: AgentCard):
"""Verify that connect works correctly when provided with a URL."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
client = await ClientFactory.connect(agent_url)

mock_resolver.assert_called_once()
assert mock_resolver.call_args[0][1] == agent_url
mock_resolver.return_value.get_agent_card.assert_awaited_once()

assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_url_and_client_config(
base_agent_card: AgentCard,
):
"""Verify connect with a URL and a pre-configured httpx client."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
mock_httpx_client = httpx.AsyncClient()
config = ClientConfig(httpx_client=mock_httpx_client)

client = await ClientFactory.connect(agent_url, client_config=config)

mock_resolver.assert_called_once_with(mock_httpx_client, agent_url)
mock_resolver.return_value.get_agent_card.assert_awaited_once()

assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_resolver_args(
base_agent_card: AgentCard,
):
"""Verify connect passes resolver arguments correctly."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
relative_path = '/card'
http_kwargs = {'headers': {'X-Test': 'true'}}

# The resolver args are only passed if an httpx_client is provided in config
config = ClientConfig(httpx_client=httpx.AsyncClient())

await ClientFactory.connect(
agent_url,
client_config=config,
relative_card_path=relative_path,
resolver_http_kwargs=http_kwargs,
)

mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
relative_card_path=relative_path,
http_kwargs=http_kwargs,
)


@pytest.mark.asyncio
async def test_client_factory_connect_resolver_args_without_client(
base_agent_card: AgentCard,
):
"""Verify resolver args are ignored if no httpx_client is provided."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
relative_path = '/card'
http_kwargs = {'headers': {'X-Test': 'true'}}

await ClientFactory.connect(
agent_url,
relative_card_path=relative_path,
resolver_http_kwargs=http_kwargs,
)

mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
relative_card_path=relative_path,
http_kwargs=http_kwargs,
)


@pytest.mark.asyncio
async def test_client_factory_connect_with_extra_transports(
base_agent_card: AgentCard,
):
"""Verify that connect can register and use extra transports."""

class CustomTransport:
pass

def custom_transport_producer(*args, **kwargs):
return CustomTransport()

base_agent_card.preferred_transport = 'custom'
base_agent_card.url = 'custom://foo'

config = ClientConfig(supported_transports=['custom'])

client = await ClientFactory.connect(
base_agent_card,
client_config=config,
extra_transports={'custom': custom_transport_producer},
)

assert isinstance(client._transport, CustomTransport)


@pytest.mark.asyncio
async def test_client_factory_connect_with_consumers_and_interceptors(
base_agent_card: AgentCard,
):
"""Verify consumers and interceptors are passed through correctly."""
consumer1 = MagicMock()
interceptor1 = MagicMock()

with patch('a2a.client.client_factory.BaseClient') as mock_base_client:
await ClientFactory.connect(
base_agent_card,
consumers=[consumer1],
interceptors=[interceptor1],
)

mock_base_client.assert_called_once()
call_args = mock_base_client.call_args[0]
assert call_args[3] == [consumer1]
assert call_args[4] == [interceptor1]
Loading