Skip to content

Commit 754cd73

Browse files
committed
[DRAFT] feat: Upgrade A2A to v1.0
This updates the SDK to be A2A v1.0 compliant, all types are generated from the v1.0 a2a.proto. JSONRPC/HTTP+JSON transports are converted to use the a2a types encoded using ProtoJSON directly from the generated types.
1 parent 2c544f4 commit 754cd73

File tree

105 files changed

+954
-4261
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+954
-4261
lines changed

buf.gen.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
version: v2
33
inputs:
44
- git_repo: https://github.com/a2aproject/A2A.git
5-
ref: main
5+
ref: transports
66
subdir: specification/grpc
77
managed:
88
enabled: true
@@ -21,11 +21,11 @@ plugins:
2121
# Generate python protobuf related code
2222
# Generates *_pb2.py files, one for each .proto
2323
- remote: buf.build/protocolbuffers/python:v29.3
24-
out: src/a2a/grpc
24+
out: src/a2a/types
2525
# Generate python service code.
2626
# Generates *_pb2_grpc.py
2727
- remote: buf.build/grpc/python
28-
out: src/a2a/grpc
28+
out: src/a2a/types
2929
# Generates *_pb2.pyi files.
3030
- remote: buf.build/protocolbuffers/pyi
31-
out: src/a2a/grpc
31+
out: src/a2a/types

pyproject.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ explicit = true
114114

115115
[tool.mypy]
116116
plugins = ["pydantic.mypy"]
117-
exclude = ["src/a2a/grpc/"]
117+
exclude = ["src/a2a/types/a2a_pb2\\.py", "src/a2a/types/a2a_pb2_grpc\\.py"]
118118
disable_error_code = [
119119
"import-not-found",
120120
"annotation-unchecked",
@@ -134,7 +134,8 @@ exclude = [
134134
"**/node_modules",
135135
"**/venv",
136136
"**/.venv",
137-
"src/a2a/grpc/",
137+
"src/a2a/types/a2a_pb2.py",
138+
"src/a2a/types/a2a_pb2_grpc.py",
138139
]
139140
reportMissingImports = "none"
140141
reportMissingModuleSource = "none"
@@ -145,7 +146,8 @@ omit = [
145146
"*/tests/*",
146147
"*/site-packages/*",
147148
"*/__init__.py",
148-
"src/a2a/grpc/*",
149+
"src/a2a/types/a2a_pb2.py",
150+
"src/a2a/types/a2a_pb2_grpc.py",
149151
]
150152

151153
[tool.coverage.report]
@@ -257,7 +259,9 @@ exclude = [
257259
"node_modules",
258260
"venv",
259261
"*/migrations/*",
260-
"src/a2a/grpc/**",
262+
"src/a2a/types/a2a_pb2.py",
263+
"src/a2a/types/a2a_pb2.pyi",
264+
"src/a2a/types/a2a_pb2_grpc.py",
261265
"tests/**",
262266
]
263267

@@ -311,7 +315,9 @@ inline-quotes = "single"
311315

312316
[tool.ruff.format]
313317
exclude = [
314-
"src/a2a/grpc/**",
318+
"src/a2a/types/a2a_pb2.py",
319+
"src/a2a/types/a2a_pb2.pyi",
320+
"src/a2a/types/a2a_pb2_grpc.py",
315321
]
316322
docstring-code-format = true
317323
docstring-code-line-length = "dynamic"

src/a2a/client/auth/interceptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from a2a.client.auth.credentials import CredentialService
55
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6-
from a2a.types import (
6+
from a2a.types.a2a_pb2 import (
77
AgentCard,
88
APIKeySecurityScheme,
99
HTTPAuthSecurityScheme,

src/a2a/client/base_client.py

Lines changed: 52 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
1-
from collections.abc import AsyncIterator
1+
from collections.abc import AsyncIterator, AsyncGenerator
22
from typing import Any
33

44
from a2a.client.client import (
55
Client,
66
ClientCallContext,
77
ClientConfig,
8-
ClientEvent,
98
Consumer,
9+
ClientEvent,
1010
)
1111
from a2a.client.client_task_manager import ClientTaskManager
1212
from a2a.client.errors import A2AClientInvalidStateError
1313
from a2a.client.middleware import ClientCallInterceptor
1414
from a2a.client.transports.base import ClientTransport
15-
from a2a.types import (
15+
from a2a.types.a2a_pb2 import (
1616
AgentCard,
1717
GetTaskPushNotificationConfigParams,
1818
Message,
19-
MessageSendConfiguration,
20-
MessageSendParams,
19+
SendMessageConfiguration,
20+
SendMessageRequest,
2121
Task,
2222
TaskArtifactUpdateEvent,
23-
TaskIdParams,
23+
SubscribeToTaskRequest,
24+
CancelTaskRequest,
2425
TaskPushNotificationConfig,
25-
TaskQueryParams,
26+
GetTaskRequest,
2627
TaskStatusUpdateEvent,
28+
StreamResponse,
2729
)
2830

2931

@@ -50,7 +52,7 @@ async def send_message(
5052
context: ClientCallContext | None = None,
5153
request_metadata: dict[str, Any] | None = None,
5254
extensions: list[str] | None = None,
53-
) -> AsyncIterator[ClientEvent | Message]:
55+
) -> AsyncIterator[StreamResponse]:
5456
"""Sends a message to the agent.
5557
5658
This method handles both streaming and non-streaming (polling) interactions
@@ -64,9 +66,9 @@ async def send_message(
6466
extensions: List of extensions to be activated.
6567
6668
Yields:
67-
An async iterator of `ClientEvent` or a final `Message` response.
69+
An async iterator of `ClientEvent`
6870
"""
69-
config = MessageSendConfiguration(
71+
config = SendMessageConfiguration(
7072
accepted_output_modes=self._config.accepted_output_modes,
7173
blocking=not self._config.polling,
7274
push_notification_config=(
@@ -75,67 +77,63 @@ async def send_message(
7577
else None
7678
),
7779
)
78-
params = MessageSendParams(
79-
message=request, configuration=config, metadata=request_metadata
80+
params = SendMessageRequest(
81+
request=request, configuration=config, metadata=request_metadata
8082
)
8183

8284
if not self._config.streaming or not self._card.capabilities.streaming:
8385
response = await self._transport.send_message(
8486
params, context=context, extensions=extensions
8587
)
86-
result = (
87-
(response, None) if isinstance(response, Task) else response
88-
)
89-
await self.consume(result, self._card)
90-
yield result
88+
89+
# In non-streaming case we convert to a StreamResponse so that the
90+
# client always sees the same iterator.
91+
stream_response = StreamResponse()
92+
if response.HasField("task"):
93+
stream_response.task = response.task
94+
client_event = (stream_response, response.task)
95+
96+
if response.HasField("message"):
97+
stream_response.message = response.message
98+
client_event = (stream_response, None)
99+
100+
await self.consume(client_event, self._card)
101+
yield client_event
91102
return
92103

93-
tracker = ClientTaskManager()
94104
stream = self._transport.send_message_streaming(
95105
params, context=context, extensions=extensions
96106
)
107+
async for client_event in self._process_stream(stream):
108+
yield client_event
97109

98-
first_event = await anext(stream)
99-
# The response from a server may be either exactly one Message or a
100-
# series of Task updates. Separate out the first message for special
101-
# case handling, which allows us to simplify further stream processing.
102-
if isinstance(first_event, Message):
103-
await self.consume(first_event, self._card)
104-
yield first_event
105-
return
106-
107-
yield await self._process_response(tracker, first_event)
110+
async def _process_stream(self, stream: AsyncIterator[StreamResponse]) -> AsyncGenerator[ClientEvent]:
111+
tracker = ClientTaskManager()
112+
async for stream_response in stream:
113+
await self.consume(stream_response)
108114

109-
async for event in stream:
110-
yield await self._process_response(tracker, event)
115+
# When we get a message in the stream then we don't expect any
116+
# further messages so yield and return
117+
if stream_response.HasField("message"):
118+
yield stream_response, None
119+
return
111120

112-
async def _process_response(
113-
self,
114-
tracker: ClientTaskManager,
115-
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
116-
) -> ClientEvent:
117-
if isinstance(event, Message):
118-
raise A2AClientInvalidStateError(
119-
'received a streamed Message from server after first response; this is not supported'
120-
)
121-
await tracker.process(event)
122-
task = tracker.get_task_or_raise()
123-
update = None if isinstance(event, Task) else event
124-
client_event = (task, update)
125-
await self.consume(client_event, self._card)
126-
return client_event
121+
# Otherwise track the task / task update then yield to the client
122+
tracker.process(stream_response)
123+
updated_task = tracker.get_task_or_raise()
124+
yield stream_response, updated_task
127125

128126
async def get_task(
129127
self,
130-
request: TaskQueryParams,
128+
request: GetTaskRequest,
131129
*,
132130
context: ClientCallContext | None = None,
133131
extensions: list[str] | None = None,
134132
) -> Task:
135133
"""Retrieves the current state and history of a specific task.
136134
137135
Args:
138-
request: The `TaskQueryParams` object specifying the task ID.
136+
request: The `GetTaskRequest` object specifying the task ID.
139137
context: The client call context.
140138
extensions: List of extensions to be activated.
141139
@@ -148,15 +146,15 @@ async def get_task(
148146

149147
async def cancel_task(
150148
self,
151-
request: TaskIdParams,
149+
request: CancelTaskRequest,
152150
*,
153151
context: ClientCallContext | None = None,
154152
extensions: list[str] | None = None,
155153
) -> Task:
156154
"""Requests the agent to cancel a specific task.
157155
158156
Args:
159-
request: The `TaskIdParams` object specifying the task ID.
157+
request: The `CancelTaskRequest` object specifying the task ID.
160158
context: The client call context.
161159
extensions: List of extensions to be activated.
162160
@@ -211,7 +209,7 @@ async def get_task_callback(
211209

212210
async def resubscribe(
213211
self,
214-
request: TaskIdParams,
212+
request: SubscribeToTaskRequest,
215213
*,
216214
context: ClientCallContext | None = None,
217215
extensions: list[str] | None = None,
@@ -240,10 +238,11 @@ async def resubscribe(
240238
# Note: resubscribe can only be called on an existing task. As such,
241239
# we should never see Message updates, despite the typing of the service
242240
# definition indicating it may be possible.
243-
async for event in self._transport.resubscribe(
241+
stream = self._transport.resubscribe(
244242
request, context=context, extensions=extensions
245-
):
246-
yield await self._process_response(tracker, event)
243+
)
244+
async for client_event in self._process_stream(stream):
245+
yield client_event
247246

248247
async def get_card(
249248
self,

src/a2a/client/card_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
A2AClientHTTPError,
1212
A2AClientJSONError,
1313
)
14-
from a2a.types import (
14+
from a2a.types.a2a_pb2 import (
1515
AgentCard,
1616
)
1717
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH

src/a2a/client/client.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1111
from a2a.client.optionals import Channel
12-
from a2a.types import (
12+
from a2a.types.a2a_pb2 import (
1313
AgentCard,
1414
GetTaskPushNotificationConfigParams,
1515
Message,
@@ -21,6 +21,7 @@
2121
TaskQueryParams,
2222
TaskStatusUpdateEvent,
2323
TransportProtocol,
24+
StreamResponse,
2425
)
2526

2627

@@ -71,14 +72,11 @@ class ClientConfig:
7172
"""A list of extension URIs the client supports."""
7273

7374

74-
UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
75-
# Alias for emitted events from client
76-
ClientEvent = tuple[Task, UpdateEvent]
75+
ClientEvent = [StreamResponse, Task | None]
76+
7777
# Alias for an event consuming callback. It takes either a (task, update) pair
7878
# or a message as well as the agent card for the agent this came from.
79-
Consumer = Callable[
80-
[ClientEvent | Message, AgentCard], Coroutine[None, Any, Any]
81-
]
79+
Consumer = Callable[ClientEvent, Coroutine[None, Any, Any]]
8280

8381

8482
class Client(ABC):
@@ -115,7 +113,7 @@ async def send_message(
115113
context: ClientCallContext | None = None,
116114
request_metadata: dict[str, Any] | None = None,
117115
extensions: list[str] | None = None,
118-
) -> AsyncIterator[ClientEvent | Message]:
116+
) -> AsyncIterator[StreamResponse]:
119117
"""Sends a message to the server.
120118
121119
This will automatically use the streaming or non-streaming approach
@@ -174,7 +172,7 @@ async def resubscribe(
174172
*,
175173
context: ClientCallContext | None = None,
176174
extensions: list[str] | None = None,
177-
) -> AsyncIterator[ClientEvent]:
175+
) -> AsyncIterator[StreamResponse]:
178176
"""Resubscribes to a task's event stream."""
179177
return
180178
yield

0 commit comments

Comments
 (0)