Skip to content

Commit 28b1d53

Browse files
committed
feat: move common functions for managing HTTP extension headers to utility.py Add extensions feature to grpc.
1 parent caba0a2 commit 28b1d53

File tree

9 files changed

+489
-262
lines changed

9 files changed

+489
-262
lines changed

src/a2a/client/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
self,
9494
consumers: list[Consumer] | None = None,
9595
middleware: list[ClientCallInterceptor] | None = None,
96+
# iva todo add optional extensions- it can override value from the config, if it is provided
9697
):
9798
"""Initializes the client with consumers and middleware.
9899
@@ -113,6 +114,8 @@ async def send_message(
113114
request: Message,
114115
*,
115116
context: ClientCallContext | None = None,
117+
# iva todo add optional extensions- it can override value from the config, if it is provided
118+
# and to the other ones as well
116119
) -> AsyncIterator[ClientEvent | Message]:
117120
"""Sends a message to the server.
118121

src/a2a/client/transports/grpc.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
from collections.abc import AsyncGenerator
4+
from typing import Any
45

56

67
try:
@@ -16,6 +17,7 @@
1617
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1718
from a2a.client.optionals import Channel
1819
from a2a.client.transports.base import ClientTransport
20+
from a2a.client.transports.utils import update_extension_header
1921
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2022
from a2a.types import (
2123
AgentCard,
@@ -44,6 +46,7 @@ def __init__(
4446
self,
4547
channel: Channel,
4648
agent_card: AgentCard | None,
49+
extensions: list[str] | None = None,
4750
):
4851
"""Initializes the GrpcTransport."""
4952
self.agent_card = agent_card
@@ -54,6 +57,25 @@ def __init__(
5457
if agent_card
5558
else True
5659
)
60+
self.extensions = extensions
61+
62+
def _get_metadata(
63+
self, context: ClientCallContext | None
64+
) -> list[tuple[str, str]]:
65+
http_kwargs: dict[str, Any] = {}
66+
if context and context.state.get('grpc_metadata'):
67+
# Convert existing metadata to headers format for update_extension_header
68+
http_kwargs['headers'] = {
69+
k: v for k, v in context.state['grpc_metadata']
70+
}
71+
72+
updated_kwargs = update_extension_header(http_kwargs, self.extensions)
73+
74+
metadata = []
75+
if 'headers' in updated_kwargs:
76+
metadata.extend(updated_kwargs['headers'].items())
77+
78+
return metadata
5779

5880
@classmethod
5981
def create(
@@ -66,10 +88,7 @@ def create(
6688
"""Creates a gRPC transport for the A2A client."""
6789
if config.grpc_channel_factory is None:
6890
raise ValueError('grpc_channel_factory is required when using gRPC')
69-
return cls(
70-
config.grpc_channel_factory(url),
71-
card,
72-
)
91+
return cls(config.grpc_channel_factory(url), card, config.extensions)
7392

7493
async def send_message(
7594
self,
@@ -85,7 +104,8 @@ async def send_message(
85104
request.configuration
86105
),
87106
metadata=proto_utils.ToProto.metadata(request.metadata),
88-
)
107+
),
108+
metadata=self._get_metadata(context),
89109
)
90110
if response.HasField('task'):
91111
return proto_utils.FromProto.task(response.task)
@@ -107,7 +127,8 @@ async def send_message_streaming(
107127
request.configuration
108128
),
109129
metadata=proto_utils.ToProto.metadata(request.metadata),
110-
)
130+
),
131+
metadata=self._get_metadata(context),
111132
)
112133
while True:
113134
response = await stream.read()
@@ -122,7 +143,8 @@ async def resubscribe(
122143
]:
123144
"""Reconnects to get task updates."""
124145
stream = self.stub.TaskSubscription(
125-
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}')
146+
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'),
147+
metadata=self._get_metadata(context),
126148
)
127149
while True:
128150
response = await stream.read()
@@ -141,7 +163,8 @@ async def get_task(
141163
a2a_pb2.GetTaskRequest(
142164
name=f'tasks/{request.id}',
143165
history_length=request.history_length,
144-
)
166+
),
167+
metadata=self._get_metadata(context),
145168
)
146169
return proto_utils.FromProto.task(task)
147170

@@ -153,7 +176,8 @@ async def cancel_task(
153176
) -> Task:
154177
"""Requests the agent to cancel a specific task."""
155178
task = await self.stub.CancelTask(
156-
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
179+
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'),
180+
metadata=self._get_metadata(context),
157181
)
158182
return proto_utils.FromProto.task(task)
159183

@@ -171,7 +195,8 @@ async def set_task_callback(
171195
config=proto_utils.ToProto.task_push_notification_config(
172196
request
173197
),
174-
)
198+
),
199+
metadata=self._get_metadata(context),
175200
)
176201
return proto_utils.FromProto.task_push_notification_config(config)
177202

@@ -185,7 +210,8 @@ async def get_task_callback(
185210
config = await self.stub.GetTaskPushNotificationConfig(
186211
a2a_pb2.GetTaskPushNotificationConfigRequest(
187212
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
188-
)
213+
),
214+
metadata=self._get_metadata(context),
189215
)
190216
return proto_utils.FromProto.task_push_notification_config(config)
191217

@@ -203,6 +229,7 @@ async def get_card(
203229

204230
card_pb = await self.stub.GetAgentCard(
205231
a2a_pb2.GetAgentCardRequest(),
232+
metadata=self._get_metadata(context), # probaby not needed
206233
)
207234
card = proto_utils.FromProto.agent_card(card_pb)
208235
self.agent_card = card

src/a2a/client/transports/jsonrpc.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
2020
from a2a.client.transports.base import ClientTransport
21-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
21+
from a2a.client.transports.utils import get_http_args, update_extension_header
2222
from a2a.types import (
2323
AgentCard,
2424
CancelTaskRequest,
@@ -83,25 +83,6 @@ def __init__(
8383
)
8484
self.extensions = extensions
8585

86-
def _update_extension_header(
87-
self, http_kwargs: dict[str, Any]
88-
) -> dict[str, Any]:
89-
if not self.extensions:
90-
return http_kwargs
91-
92-
headers = http_kwargs.setdefault('headers', {})
93-
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
94-
95-
existing_extensions = [
96-
e.strip() for e in existing_extensions_str.split(',') if e.strip()
97-
]
98-
99-
all_extensions = set(self.extensions)
100-
all_extensions.update(existing_extensions)
101-
102-
headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions))
103-
return http_kwargs
104-
10586
async def _apply_interceptors(
10687
self,
10788
method_name: str,
@@ -125,11 +106,6 @@ async def _apply_interceptors(
125106
)
126107
return final_request_payload, final_http_kwargs
127108

128-
def _get_http_args(
129-
self, context: ClientCallContext | None
130-
) -> dict[str, Any] | None:
131-
return context.state.get('http_kwargs') if context else None
132-
133109
async def send_message(
134110
self,
135111
request: MessageSendParams,
@@ -141,10 +117,12 @@ async def send_message(
141117
payload, modified_kwargs = await self._apply_interceptors(
142118
'message/send',
143119
rpc_request.model_dump(mode='json', exclude_none=True),
144-
self._get_http_args(context),
120+
get_http_args(context),
145121
context,
146122
)
147-
modified_kwargs = self._update_extension_header(modified_kwargs)
123+
modified_kwargs = update_extension_header(
124+
modified_kwargs, self.extensions
125+
)
148126
response_data = await self._send_request(payload, modified_kwargs)
149127
response = SendMessageResponse.model_validate(response_data)
150128
if isinstance(response.root, JSONRPCErrorResponse):
@@ -166,11 +144,13 @@ async def send_message_streaming(
166144
payload, modified_kwargs = await self._apply_interceptors(
167145
'message/stream',
168146
rpc_request.model_dump(mode='json', exclude_none=True),
169-
self._get_http_args(context),
147+
get_http_args(context),
170148
context,
171149
)
172150

173-
modified_kwargs = self._update_extension_header(modified_kwargs)
151+
modified_kwargs = update_extension_header(
152+
modified_kwargs, self.extensions
153+
)
174154
modified_kwargs.setdefault(
175155
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
176156
)
@@ -237,7 +217,7 @@ async def get_task(
237217
payload, modified_kwargs = await self._apply_interceptors(
238218
'tasks/get',
239219
rpc_request.model_dump(mode='json', exclude_none=True),
240-
self._get_http_args(context),
220+
get_http_args(context),
241221
context,
242222
)
243223
response_data = await self._send_request(payload, modified_kwargs)
@@ -257,7 +237,7 @@ async def cancel_task(
257237
payload, modified_kwargs = await self._apply_interceptors(
258238
'tasks/cancel',
259239
rpc_request.model_dump(mode='json', exclude_none=True),
260-
self._get_http_args(context),
240+
get_http_args(context),
261241
context,
262242
)
263243
response_data = await self._send_request(payload, modified_kwargs)
@@ -279,7 +259,7 @@ async def set_task_callback(
279259
payload, modified_kwargs = await self._apply_interceptors(
280260
'tasks/pushNotificationConfig/set',
281261
rpc_request.model_dump(mode='json', exclude_none=True),
282-
self._get_http_args(context),
262+
get_http_args(context),
283263
context,
284264
)
285265
response_data = await self._send_request(payload, modified_kwargs)
@@ -303,7 +283,7 @@ async def get_task_callback(
303283
payload, modified_kwargs = await self._apply_interceptors(
304284
'tasks/pushNotificationConfig/get',
305285
rpc_request.model_dump(mode='json', exclude_none=True),
306-
self._get_http_args(context),
286+
get_http_args(context),
307287
context,
308288
)
309289
response_data = await self._send_request(payload, modified_kwargs)
@@ -327,7 +307,7 @@ async def resubscribe(
327307
payload, modified_kwargs = await self._apply_interceptors(
328308
'tasks/resubscribe',
329309
rpc_request.model_dump(mode='json', exclude_none=True),
330-
self._get_http_args(context),
310+
get_http_args(context),
331311
context,
332312
)
333313

@@ -369,7 +349,7 @@ async def get_card(
369349
if not card:
370350
resolver = A2ACardResolver(self.httpx_client, self.url)
371351
card = await resolver.get_agent_card(
372-
http_kwargs=self._get_http_args(context)
352+
http_kwargs=get_http_args(context)
373353
)
374354
self._needs_extended_card = (
375355
card.supports_authenticated_extended_card
@@ -383,7 +363,7 @@ async def get_card(
383363
payload, modified_kwargs = await self._apply_interceptors(
384364
request.method,
385365
request.model_dump(mode='json', exclude_none=True),
386-
self._get_http_args(context),
366+
get_http_args(context),
387367
context,
388368
)
389369

0 commit comments

Comments
 (0)