Skip to content

Commit 2e8fbc4

Browse files
authored
fix: Add validation for transport types in ClientFactory (#396)
### Description This update introduces upfront validation for transport protocols within the `ClientFactory`. Previously, if an invalid or misspelled transport type was provided in the `ClientConfig`, it would be silently ignored. This could lead to confusing `ValueError('no compatible transports found.')` errors later on, making it difficult to debug misconfigurations. ### Changes * **`src/a2a/client/client_factory.py`**: Added a validation loop at the beginning of the `create` method to check all transport types provided in the `ClientConfig`. * **`tests/client/test_client_factory.py`**: Added a new unit test, `test_client_factory_invalid_transport_in_config`, to confirm that a `ValueError` is raised when an invalid transport string is used.
1 parent a96df5c commit 2e8fbc4

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/a2a/client/client_factory.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _register_defaults(
9191
if GrpcTransport is None:
9292
raise ImportError(
9393
'To use GrpcClient, its dependencies must be installed. '
94-
'You can install them with \'pip install "a2a-sdk[grpc]"\''
94+
'You can install them with \'pip install "a2a-sdk[grpc]"\'',
9595
)
9696
self.register(
9797
TransportProtocol.grpc,
@@ -124,6 +124,20 @@ def create(
124124
If there is no valid matching of the client configuration with the
125125
server configuration, a `ValueError` is raised.
126126
"""
127+
valid_transports = {member.value for member in TransportProtocol}
128+
configured_transports = set(self._config.supported_transports)
129+
130+
invalid_transports = configured_transports.difference(valid_transports)
131+
if invalid_transports:
132+
invalid_str = ', '.join(
133+
sorted(f"'{t}'" for t in invalid_transports)
134+
)
135+
valid_str = ', '.join(sorted(valid_transports))
136+
raise ValueError(
137+
f'Unsupported transport type(s) in ClientConfig: {invalid_str}. '
138+
f'Valid types are: {valid_str}'
139+
)
140+
127141
server_preferred = card.preferred_transport or TransportProtocol.jsonrpc
128142
server_set = {server_preferred: card.url}
129143
if card.additional_interfaces:

tests/client/test_client_factory.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,32 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
103103
factory = ClientFactory(config)
104104
with pytest.raises(ValueError, match='no compatible transports found'):
105105
factory.create(base_agent_card)
106+
107+
108+
@pytest.mark.parametrize(
109+
('invalid_transports', 'expected_match'),
110+
[
111+
(
112+
['invalid-transport'],
113+
"Unsupported transport type\\(s\\) in ClientConfig: 'invalid-transport'",
114+
),
115+
(
116+
['invalid-1', 'another-bad-one'],
117+
"Unsupported transport type\\(s\\) in ClientConfig: 'another-bad-one', 'invalid-1'",
118+
),
119+
],
120+
)
121+
def test_client_factory_invalid_transport_in_config(
122+
base_agent_card: AgentCard, invalid_transports, expected_match
123+
):
124+
"""Verify that the factory raises an error for unknown transport types."""
125+
config = ClientConfig(
126+
httpx_client=httpx.AsyncClient(),
127+
supported_transports=[
128+
TransportProtocol.jsonrpc,
129+
*invalid_transports,
130+
],
131+
)
132+
factory = ClientFactory(config)
133+
with pytest.raises(ValueError, match=expected_match):
134+
factory.create(base_agent_card)

0 commit comments

Comments
 (0)