Skip to content

Commit b348735

Browse files
committed
[DRAFT] feat: Upgrade A2A to v1.0
This updates the SDK to be A2A v1.0 compliant, all types are generated from the v1.0 a2a.proto. JSONRPC/HTTP+JSON transports are converted to use the a2a types encoded using ProtoJSON directly from the generated types.
1 parent 2c544f4 commit b348735

File tree

106 files changed

+1252
-4664
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+1252
-4664
lines changed

buf.gen.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
version: v2
33
inputs:
44
- git_repo: https://github.com/a2aproject/A2A.git
5-
ref: main
5+
ref: transports
66
subdir: specification/grpc
77
managed:
88
enabled: true
@@ -21,11 +21,11 @@ plugins:
2121
# Generate python protobuf related code
2222
# Generates *_pb2.py files, one for each .proto
2323
- remote: buf.build/protocolbuffers/python:v29.3
24-
out: src/a2a/grpc
24+
out: src/a2a/types
2525
# Generate python service code.
2626
# Generates *_pb2_grpc.py
2727
- remote: buf.build/grpc/python
28-
out: src/a2a/grpc
28+
out: src/a2a/types
2929
# Generates *_pb2.pyi files.
3030
- remote: buf.build/protocolbuffers/pyi
31-
out: src/a2a/grpc
31+
out: src/a2a/types

pyproject.toml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies = [
1313
"pydantic>=2.11.3",
1414
"protobuf>=5.29.5",
1515
"google-api-core>=1.26.0",
16+
"json-rpc>=1.15.0",
1617
]
1718

1819
classifiers = [
@@ -114,7 +115,7 @@ explicit = true
114115

115116
[tool.mypy]
116117
plugins = ["pydantic.mypy"]
117-
exclude = ["src/a2a/grpc/"]
118+
exclude = ["src/a2a/types/a2a_pb2\\.py", "src/a2a/types/a2a_pb2_grpc\\.py"]
118119
disable_error_code = [
119120
"import-not-found",
120121
"annotation-unchecked",
@@ -134,7 +135,8 @@ exclude = [
134135
"**/node_modules",
135136
"**/venv",
136137
"**/.venv",
137-
"src/a2a/grpc/",
138+
"src/a2a/types/a2a_pb2.py",
139+
"src/a2a/types/a2a_pb2_grpc.py",
138140
]
139141
reportMissingImports = "none"
140142
reportMissingModuleSource = "none"
@@ -145,7 +147,8 @@ omit = [
145147
"*/tests/*",
146148
"*/site-packages/*",
147149
"*/__init__.py",
148-
"src/a2a/grpc/*",
150+
"src/a2a/types/a2a_pb2.py",
151+
"src/a2a/types/a2a_pb2_grpc.py",
149152
]
150153

151154
[tool.coverage.report]
@@ -257,7 +260,9 @@ exclude = [
257260
"node_modules",
258261
"venv",
259262
"*/migrations/*",
260-
"src/a2a/grpc/**",
263+
"src/a2a/types/a2a_pb2.py",
264+
"src/a2a/types/a2a_pb2.pyi",
265+
"src/a2a/types/a2a_pb2_grpc.py",
261266
"tests/**",
262267
]
263268

@@ -311,7 +316,9 @@ inline-quotes = "single"
311316

312317
[tool.ruff.format]
313318
exclude = [
314-
"src/a2a/grpc/**",
319+
"src/a2a/types/a2a_pb2.py",
320+
"src/a2a/types/a2a_pb2.pyi",
321+
"src/a2a/types/a2a_pb2_grpc.py",
315322
]
316323
docstring-code-format = true
317324
docstring-code-line-length = "dynamic"

src/a2a/client/auth/interceptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from a2a.client.auth.credentials import CredentialService
55
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6-
from a2a.types import (
6+
from a2a.types.a2a_pb2 import (
77
AgentCard,
88
APIKeySecurityScheme,
99
HTTPAuthSecurityScheme,

src/a2a/client/base_client.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
1-
from collections.abc import AsyncIterator
1+
from collections.abc import AsyncIterator, AsyncGenerator
22
from typing import Any
33

44
from a2a.client.client import (
55
Client,
66
ClientCallContext,
77
ClientConfig,
8-
ClientEvent,
98
Consumer,
9+
ClientEvent,
1010
)
1111
from a2a.client.client_task_manager import ClientTaskManager
1212
from a2a.client.errors import A2AClientInvalidStateError
1313
from a2a.client.middleware import ClientCallInterceptor
1414
from a2a.client.transports.base import ClientTransport
15-
from a2a.types import (
15+
from a2a.types.a2a_pb2 import (
1616
AgentCard,
17-
GetTaskPushNotificationConfigParams,
1817
Message,
19-
MessageSendConfiguration,
20-
MessageSendParams,
18+
SendMessageConfiguration,
19+
SendMessageRequest,
2120
Task,
2221
TaskArtifactUpdateEvent,
23-
TaskIdParams,
22+
SubscribeToTaskRequest,
23+
CancelTaskRequest,
2424
TaskPushNotificationConfig,
25-
TaskQueryParams,
25+
GetTaskRequest,
2626
TaskStatusUpdateEvent,
27+
StreamResponse,
28+
SetTaskPushNotificationConfigRequest,
29+
GetExtendedAgentCardRequest,
30+
GetTaskPushNotificationConfigRequest,
2731
)
2832

2933

@@ -50,7 +54,7 @@ async def send_message(
5054
context: ClientCallContext | None = None,
5155
request_metadata: dict[str, Any] | None = None,
5256
extensions: list[str] | None = None,
53-
) -> AsyncIterator[ClientEvent | Message]:
57+
) -> AsyncIterator[ClientEvent]:
5458
"""Sends a message to the agent.
5559
5660
This method handles both streaming and non-streaming (polling) interactions
@@ -64,9 +68,9 @@ async def send_message(
6468
extensions: List of extensions to be activated.
6569
6670
Yields:
67-
An async iterator of `ClientEvent` or a final `Message` response.
71+
An async iterator of `ClientEvent`
6872
"""
69-
config = MessageSendConfiguration(
73+
config = SendMessageConfiguration(
7074
accepted_output_modes=self._config.accepted_output_modes,
7175
blocking=not self._config.polling,
7276
push_notification_config=(
@@ -75,67 +79,67 @@ async def send_message(
7579
else None
7680
),
7781
)
78-
params = MessageSendParams(
79-
message=request, configuration=config, metadata=request_metadata
82+
sendMessageRequest = SendMessageRequest(
83+
request=request, configuration=config, metadata=request_metadata
8084
)
8185

8286
if not self._config.streaming or not self._card.capabilities.streaming:
8387
response = await self._transport.send_message(
84-
params, context=context, extensions=extensions
85-
)
86-
result = (
87-
(response, None) if isinstance(response, Task) else response
88+
sendMessageRequest, context=context, extensions=extensions
8889
)
89-
await self.consume(result, self._card)
90-
yield result
91-
return
9290

93-
tracker = ClientTaskManager()
94-
stream = self._transport.send_message_streaming(
95-
params, context=context, extensions=extensions
96-
)
91+
# In non-streaming case we convert to a StreamResponse so that the
92+
# client always sees the same iterator.
93+
stream_response = StreamResponse()
94+
client_event: ClientEvent
95+
if response.HasField("task"):
96+
stream_response.task = response.task
97+
client_event = (stream_response, response.task)
9798

98-
first_event = await anext(stream)
99-
# The response from a server may be either exactly one Message or a
100-
# series of Task updates. Separate out the first message for special
101-
# case handling, which allows us to simplify further stream processing.
102-
if isinstance(first_event, Message):
103-
await self.consume(first_event, self._card)
104-
yield first_event
105-
return
99+
elif response.HasField("message"):
100+
stream_response.msg = response.msg
101+
client_event = (stream_response, None)
106102

107-
yield await self._process_response(tracker, first_event)
103+
await self.consume(client_event, self._card)
104+
yield client_event
105+
return
108106

109-
async for event in stream:
110-
yield await self._process_response(tracker, event)
107+
stream = self._transport.send_message_streaming(
108+
sendMessageRequest, context=context, extensions=extensions
109+
)
110+
async for client_event in self._process_stream(stream):
111+
yield client_event
111112

112-
async def _process_response(
113-
self,
114-
tracker: ClientTaskManager,
115-
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
116-
) -> ClientEvent:
117-
if isinstance(event, Message):
118-
raise A2AClientInvalidStateError(
119-
'received a streamed Message from server after first response; this is not supported'
120-
)
121-
await tracker.process(event)
122-
task = tracker.get_task_or_raise()
123-
update = None if isinstance(event, Task) else event
124-
client_event = (task, update)
125-
await self.consume(client_event, self._card)
126-
return client_event
113+
async def _process_stream(self, stream: AsyncIterator[StreamResponse]) -> AsyncGenerator[ClientEvent]:
114+
tracker = ClientTaskManager()
115+
async for stream_response in stream:
116+
client_event: ClientEvent
117+
# When we get a message in the stream then we don't expect any
118+
# further messages so yield and return
119+
if stream_response.HasField("message"):
120+
client_event = (stream_response, None)
121+
await self.consume(client_event, self._card)
122+
yield client_event
123+
return
124+
125+
# Otherwise track the task / task update then yield to the client
126+
await tracker.process(stream_response)
127+
updated_task = tracker.get_task_or_raise()
128+
client_event = (stream_response, updated_task)
129+
await self.consume(client_event, self._card)
130+
yield client_event
127131

128132
async def get_task(
129133
self,
130-
request: TaskQueryParams,
134+
request: GetTaskRequest,
131135
*,
132136
context: ClientCallContext | None = None,
133137
extensions: list[str] | None = None,
134138
) -> Task:
135139
"""Retrieves the current state and history of a specific task.
136140
137141
Args:
138-
request: The `TaskQueryParams` object specifying the task ID.
142+
request: The `GetTaskRequest` object specifying the task ID.
139143
context: The client call context.
140144
extensions: List of extensions to be activated.
141145
@@ -148,15 +152,15 @@ async def get_task(
148152

149153
async def cancel_task(
150154
self,
151-
request: TaskIdParams,
155+
request: CancelTaskRequest,
152156
*,
153157
context: ClientCallContext | None = None,
154158
extensions: list[str] | None = None,
155159
) -> Task:
156160
"""Requests the agent to cancel a specific task.
157161
158162
Args:
159-
request: The `TaskIdParams` object specifying the task ID.
163+
request: The `CancelTaskRequest` object specifying the task ID.
160164
context: The client call context.
161165
extensions: List of extensions to be activated.
162166
@@ -169,7 +173,7 @@ async def cancel_task(
169173

170174
async def set_task_callback(
171175
self,
172-
request: TaskPushNotificationConfig,
176+
request: SetTaskPushNotificationConfigRequest,
173177
*,
174178
context: ClientCallContext | None = None,
175179
extensions: list[str] | None = None,
@@ -190,7 +194,7 @@ async def set_task_callback(
190194

191195
async def get_task_callback(
192196
self,
193-
request: GetTaskPushNotificationConfigParams,
197+
request: GetTaskPushNotificationConfigRequest,
194198
*,
195199
context: ClientCallContext | None = None,
196200
extensions: list[str] | None = None,
@@ -209,9 +213,9 @@ async def get_task_callback(
209213
request, context=context, extensions=extensions
210214
)
211215

212-
async def resubscribe(
216+
async def subscribe(
213217
self,
214-
request: TaskIdParams,
218+
request: SubscribeToTaskRequest,
215219
*,
216220
context: ClientCallContext | None = None,
217221
extensions: list[str] | None = None,
@@ -240,12 +244,13 @@ async def resubscribe(
240244
# Note: resubscribe can only be called on an existing task. As such,
241245
# we should never see Message updates, despite the typing of the service
242246
# definition indicating it may be possible.
243-
async for event in self._transport.resubscribe(
247+
stream = self._transport.subscribe(
244248
request, context=context, extensions=extensions
245-
):
246-
yield await self._process_response(tracker, event)
249+
)
250+
async for client_event in self._process_stream(stream):
251+
yield client_event
247252

248-
async def get_card(
253+
async def get_extended_agent_card(
249254
self,
250255
*,
251256
context: ClientCallContext | None = None,
@@ -263,7 +268,7 @@ async def get_card(
263268
Returns:
264269
The `AgentCard` for the agent.
265270
"""
266-
card = await self._transport.get_card(
271+
card = await self._transport.get_extended_agent_card(
267272
context=context, extensions=extensions
268273
)
269274
self._card = card

src/a2a/client/card_resolver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
from pydantic import ValidationError
99

10+
from google.protobuf.json_format import ParseDict
1011
from a2a.client.errors import (
1112
A2AClientHTTPError,
1213
A2AClientJSONError,
1314
)
14-
from a2a.types import (
15+
from a2a.types.a2a_pb2 import (
1516
AgentCard,
1617
)
1718
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
@@ -85,7 +86,7 @@ async def get_agent_card(
8586
target_url,
8687
agent_card_data,
8788
)
88-
agent_card = AgentCard.model_validate(agent_card_data)
89+
agent_card = ParseDict(agent_card_data, AgentCard())
8990
except httpx.HTTPStatusError as e:
9091
raise A2AClientHTTPError(
9192
e.response.status_code,

0 commit comments

Comments
 (0)