Skip to content

Commit 6e86e78

Browse files
authored
Merge branch 'main' into poc/agent-catalog
2 parents 55ae0e3 + 0c9df12 commit 6e86e78

File tree

9 files changed

+247
-72
lines changed

9 files changed

+247
-72
lines changed

src/a2a/client/client.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
11
import json
22
import logging
3+
34
from collections.abc import AsyncGenerator
45
from typing import Any
56
from uuid import uuid4
67

78
import httpx
9+
810
from httpx_sse import SSEError, aconnect_sse
911
from pydantic import ValidationError
1012

1113
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
12-
from a2a.types import (AgentCard, CancelTaskRequest, CancelTaskResponse,
13-
GetTaskPushNotificationConfigRequest,
14-
GetTaskPushNotificationConfigResponse, GetTaskRequest,
15-
GetTaskResponse, SendMessageRequest,
16-
SendMessageResponse, SendStreamingMessageRequest,
17-
SendStreamingMessageResponse,
18-
SetTaskPushNotificationConfigRequest,
19-
SetTaskPushNotificationConfigResponse)
14+
from a2a.types import (
15+
AgentCard,
16+
CancelTaskRequest,
17+
CancelTaskResponse,
18+
GetTaskPushNotificationConfigRequest,
19+
GetTaskPushNotificationConfigResponse,
20+
GetTaskRequest,
21+
GetTaskResponse,
22+
SendMessageRequest,
23+
SendMessageResponse,
24+
SendStreamingMessageRequest,
25+
SendStreamingMessageResponse,
26+
SetTaskPushNotificationConfigRequest,
27+
SetTaskPushNotificationConfigResponse,
28+
)
2029
from a2a.utils.telemetry import SpanKind, trace_class
2130

31+
2232
logger = logging.getLogger(__name__)
2333

34+
2435
class A2ACardResolver:
2536
"""Agent Card resolver."""
2637

@@ -160,6 +171,7 @@ async def get_client_from_agent_card_url(
160171
agent_card_path: The path to the agent card endpoint, relative to the base URL.
161172
http_kwargs: Optional dictionary of keyword arguments to pass to the
162173
underlying httpx.get request when fetching the agent card.
174+
163175
Returns:
164176
An initialized `A2AClient` instance.
165177
@@ -169,7 +181,9 @@ async def get_client_from_agent_card_url(
169181
"""
170182
agent_card: AgentCard = await A2ACardResolver(
171183
httpx_client, base_url=base_url, agent_card_path=agent_card_path
172-
).get_agent_card(http_kwargs=http_kwargs) # Fetches public card by default
184+
).get_agent_card(
185+
http_kwargs=http_kwargs
186+
) # Fetches public card by default
173187
return A2AClient(httpx_client=httpx_client, agent_card=agent_card)
174188

175189
async def send_message(

src/a2a/server/apps/starlette_app.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,20 @@
5151

5252
logger = logging.getLogger(__name__)
5353

54-
# Register Starlette User as an implementation of a2a.auth.user.User
55-
A2AUser.register(BaseUser)
54+
55+
class StarletteUserProxy(A2AUser):
56+
"""Adapts the Starlette User class to the A2A user representation."""
57+
58+
def __init__(self, user: BaseUser):
59+
self._user = user
60+
61+
@property
62+
def is_authenticated(self):
63+
return self._user.is_authenticated
64+
65+
@property
66+
def user_name(self):
67+
return self._user.display_name
5668

5769

5870
class CallContextBuilder(ABC):
@@ -70,7 +82,7 @@ def build(self, request: Request) -> ServerCallContext:
7082
user = UnauthenticatedUser()
7183
state = {}
7284
with contextlib.suppress(Exception):
73-
user = request.user
85+
user = StarletteUserProxy(request.user)
7486
state['auth'] = request.auth
7587
return ServerCallContext(user=user, state=state)
7688

src/a2a/server/events/event_queue.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17-
Event = (
18-
Message
19-
| Task
20-
| TaskStatusUpdateEvent
21-
| TaskArtifactUpdateEvent
22-
)
17+
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
2318
"""Type alias for events that can be enqueued."""
2419

2520

src/a2a/server/tasks/task_updater.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22

3+
from datetime import datetime, timezone
34
from typing import Any
45

56
from a2a.server.events import EventQueue
@@ -34,15 +35,23 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3435
self.context_id = context_id
3536

3637
def update_status(
37-
self, state: TaskState, message: Message | None = None, final=False
38+
self,
39+
state: TaskState,
40+
message: Message | None = None,
41+
final=False,
42+
timestamp: str | None = None,
3843
):
3944
"""Updates the status of the task and publishes a `TaskStatusUpdateEvent`.
4045
4146
Args:
4247
state: The new state of the task.
4348
message: An optional message associated with the status update.
4449
final: If True, indicates this is the final status update for the task.
50+
timestamp: Optional ISO 8601 datetime string. Defaults to current time.
4551
"""
52+
current_timestamp = (
53+
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
54+
)
4655
self.event_queue.enqueue_event(
4756
TaskStatusUpdateEvent(
4857
taskId=self.task_id,
@@ -51,6 +60,7 @@ def update_status(
5160
status=TaskStatus(
5261
state=state,
5362
message=message,
63+
timestamp=current_timestamp,
5464
),
5565
)
5666
)
@@ -97,6 +107,10 @@ def failed(self, message: Message | None = None):
97107
"""Marks the task as failed and publishes a final status update."""
98108
self.update_status(TaskState.failed, message=message, final=True)
99109

110+
def reject(self, message: Message | None = None):
111+
"""Marks the task as rejected and publishes a final status update."""
112+
self.update_status(TaskState.rejected, message=message, final=True)
113+
100114
def submit(self, message: Message | None = None):
101115
"""Marks the task as submitted and publishes a status update."""
102116
self.update_status(

src/a2a/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,10 @@ class PushNotificationConfig(BaseModel):
648648
"""
649649

650650
authentication: PushNotificationAuthenticationInfo | None = None
651+
id: str | None = None
652+
"""
653+
Push Notification ID - created by server to support multiple callbacks
654+
"""
651655
token: str | None = None
652656
"""
653657
Token unique to this task/session.

tests/client/test_client.py

Lines changed: 81 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,45 @@
11
import json
2+
23
from collections.abc import AsyncGenerator
34
from typing import Any
45
from unittest.mock import AsyncMock, MagicMock, patch
56

67
import httpx
78
import pytest
9+
810
from httpx_sse import EventSource, ServerSentEvent
9-
from pydantic import ValidationError as PydanticValidationError
10-
11-
from a2a.client import (A2ACardResolver, A2AClient, A2AClientHTTPError,
12-
A2AClientJSONError, create_text_message_object)
13-
from a2a.types import (A2ARequest, AgentCapabilities, AgentCard, AgentSkill,
14-
CancelTaskRequest, CancelTaskResponse,
15-
CancelTaskSuccessResponse, GetTaskRequest,
16-
GetTaskResponse, InvalidParamsError,
17-
JSONRPCErrorResponse, MessageSendParams, Role,
18-
SendMessageRequest, SendMessageResponse,
19-
SendMessageSuccessResponse, SendStreamingMessageRequest,
20-
SendStreamingMessageResponse, TaskIdParams,
21-
TaskNotCancelableError, TaskQueryParams)
11+
12+
from a2a.client import (
13+
A2ACardResolver,
14+
A2AClient,
15+
A2AClientHTTPError,
16+
A2AClientJSONError,
17+
create_text_message_object,
18+
)
19+
from a2a.types import (
20+
A2ARequest,
21+
AgentCapabilities,
22+
AgentCard,
23+
AgentSkill,
24+
CancelTaskRequest,
25+
CancelTaskResponse,
26+
CancelTaskSuccessResponse,
27+
GetTaskRequest,
28+
GetTaskResponse,
29+
InvalidParamsError,
30+
JSONRPCErrorResponse,
31+
MessageSendParams,
32+
Role,
33+
SendMessageRequest,
34+
SendMessageResponse,
35+
SendMessageSuccessResponse,
36+
SendStreamingMessageRequest,
37+
SendStreamingMessageResponse,
38+
TaskIdParams,
39+
TaskNotCancelableError,
40+
TaskQueryParams,
41+
)
42+
2243

2344
AGENT_CARD = AgentCard(
2445
name='Hello World Agent',
@@ -100,7 +121,9 @@ class TestA2ACardResolver:
100121
BASE_URL = 'http://example.com'
101122
AGENT_CARD_PATH = '/.well-known/agent.json'
102123
FULL_AGENT_CARD_URL = f'{BASE_URL}{AGENT_CARD_PATH}'
103-
EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' # Default path
124+
EXTENDED_AGENT_CARD_PATH = (
125+
'/agent/authenticatedExtendedCard' # Default path
126+
)
104127

105128
@pytest.mark.asyncio
106129
async def test_init_strips_slashes(self, mock_httpx_client: AsyncMock):
@@ -141,11 +164,12 @@ async def test_get_agent_card_success_public_only(
141164

142165
@pytest.mark.asyncio
143166
async def test_get_agent_card_success_with_specified_path_for_extended_card(
144-
self, mock_httpx_client: AsyncMock):
167+
self, mock_httpx_client: AsyncMock
168+
):
145169
extended_card_response = AsyncMock(spec=httpx.Response)
146170
extended_card_response.status_code = 200
147-
extended_card_response.json.return_value = AGENT_CARD_EXTENDED.model_dump(
148-
mode='json'
171+
extended_card_response.json.return_value = (
172+
AGENT_CARD_EXTENDED.model_dump(mode='json')
149173
)
150174

151175
# Mock the single call for the extended card
@@ -156,20 +180,26 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card(
156180
base_url=self.BASE_URL,
157181
agent_card_path=self.AGENT_CARD_PATH,
158182
)
159-
183+
160184
# Fetch the extended card by providing its relative path and example auth
161-
auth_kwargs = {"headers": {"Authorization": "Bearer test token"}}
185+
auth_kwargs = {'headers': {'Authorization': 'Bearer test token'}}
162186
agent_card_result = await resolver.get_agent_card(
163187
relative_card_path=self.EXTENDED_AGENT_CARD_PATH,
164-
http_kwargs=auth_kwargs
188+
http_kwargs=auth_kwargs,
165189
)
166190

167-
expected_extended_url = f'{self.BASE_URL}/{self.EXTENDED_AGENT_CARD_PATH.lstrip("/")}'
168-
mock_httpx_client.get.assert_called_once_with(expected_extended_url, **auth_kwargs)
191+
expected_extended_url = (
192+
f'{self.BASE_URL}/{self.EXTENDED_AGENT_CARD_PATH.lstrip("/")}'
193+
)
194+
mock_httpx_client.get.assert_called_once_with(
195+
expected_extended_url, **auth_kwargs
196+
)
169197
extended_card_response.raise_for_status.assert_called_once()
170198

171199
assert isinstance(agent_card_result, AgentCard)
172-
assert agent_card_result == AGENT_CARD_EXTENDED # Should return the extended card
200+
assert (
201+
agent_card_result == AGENT_CARD_EXTENDED
202+
) # Should return the extended card
173203

174204
@pytest.mark.asyncio
175205
async def test_get_agent_card_validation_error(
@@ -178,19 +208,29 @@ async def test_get_agent_card_validation_error(
178208
mock_response = AsyncMock(spec=httpx.Response)
179209
mock_response.status_code = 200
180210
# Data that will cause a Pydantic ValidationError
181-
mock_response.json.return_value = {"invalid_field": "value", "name": "Test Agent"}
211+
mock_response.json.return_value = {
212+
'invalid_field': 'value',
213+
'name': 'Test Agent',
214+
}
182215
mock_httpx_client.get.return_value = mock_response
183216

184217
resolver = A2ACardResolver(
185218
httpx_client=mock_httpx_client, base_url=self.BASE_URL
186219
)
187220
# The call that is expected to raise an error should be within pytest.raises
188221
with pytest.raises(A2AClientJSONError) as exc_info:
189-
await resolver.get_agent_card() # Fetches from default path
190-
191-
assert f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value)
192-
assert 'invalid_field' in str(exc_info.value) # Check if Pydantic error details are present
193-
assert mock_httpx_client.get.call_count == 1 # Should only be called once
222+
await resolver.get_agent_card() # Fetches from default path
223+
224+
assert (
225+
f'Failed to validate agent card structure from {self.FULL_AGENT_CARD_URL}'
226+
in str(exc_info.value)
227+
)
228+
assert 'invalid_field' in str(
229+
exc_info.value
230+
) # Check if Pydantic error details are present
231+
assert (
232+
mock_httpx_client.get.call_count == 1
233+
) # Should only be called once
194234

195235
@pytest.mark.asyncio
196236
async def test_get_agent_card_http_status_error(
@@ -217,7 +257,10 @@ async def test_get_agent_card_http_status_error(
217257
await resolver.get_agent_card()
218258

219259
assert exc_info.value.status_code == 404
220-
assert f'Failed to fetch agent card from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value)
260+
assert (
261+
f'Failed to fetch agent card from {self.FULL_AGENT_CARD_URL}'
262+
in str(exc_info.value)
263+
)
221264
assert 'Not Found' in str(exc_info.value)
222265
mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL)
223266

@@ -242,7 +285,10 @@ async def test_get_agent_card_json_decode_error(
242285
await resolver.get_agent_card()
243286

244287
# Assertions using exc_info must be after the with block
245-
assert f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value)
288+
assert (
289+
f'Failed to parse JSON for agent card from {self.FULL_AGENT_CARD_URL}'
290+
in str(exc_info.value)
291+
)
246292
assert 'Expecting value' in str(exc_info.value)
247293
mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL)
248294

@@ -263,7 +309,10 @@ async def test_get_agent_card_request_error(
263309
await resolver.get_agent_card()
264310

265311
assert exc_info.value.status_code == 503
266-
assert f'Network communication error fetching agent card from {self.FULL_AGENT_CARD_URL}' in str(exc_info.value)
312+
assert (
313+
f'Network communication error fetching agent card from {self.FULL_AGENT_CARD_URL}'
314+
in str(exc_info.value)
315+
)
267316
assert 'Network issue' in str(exc_info.value)
268317
mock_httpx_client.get.assert_called_once_with(self.FULL_AGENT_CARD_URL)
269318

tests/server/tasks/test_task_updater.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ def test_new_agent_message(self, task_updater, sample_parts):
212212
assert message.parts == sample_parts
213213
assert message.metadata is None
214214

215-
def test_new_agent_message_with_metadata(
216-
self, task_updater, sample_parts
217-
):
215+
def test_new_agent_message_with_metadata(self, task_updater, sample_parts):
218216
"""Test creating a new agent message with metadata and final=True."""
219217
metadata = {'key': 'value'}
220218

0 commit comments

Comments
 (0)