Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2844180

Browse files
committedJun 2, 2025··
Add integration test for server auth
1 parent f8e3af0 commit 2844180

File tree

1 file changed

+101
-15
lines changed

1 file changed

+101
-15
lines changed
 

‎tests/server/test_integration.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,42 @@
33
from unittest import mock
44

55
import pytest
6+
from starlette.authentication import (
7+
AuthCredentials,
8+
AuthenticationBackend,
9+
BaseUser,
10+
SimpleUser,
11+
)
12+
from starlette.middleware import Middleware
13+
from starlette.middleware.authentication import AuthenticationMiddleware
14+
from starlette.requests import HTTPConnection
615
from starlette.responses import JSONResponse
716
from starlette.routing import Route
817
from starlette.testclient import TestClient
918

1019
from a2a.server.apps.starlette_app import A2AStarletteApplication
11-
from a2a.types import (AgentCapabilities, AgentCard, Artifact, DataPart,
12-
InternalError, InvalidRequestError, JSONParseError,
13-
Part, PushNotificationConfig, Task,
14-
TaskArtifactUpdateEvent, TaskPushNotificationConfig,
15-
TaskState, TaskStatus, TextPart,
16-
UnsupportedOperationError)
20+
from a2a.types import (
21+
AgentCapabilities,
22+
AgentCard,
23+
Artifact,
24+
DataPart,
25+
InternalError,
26+
InvalidRequestError,
27+
JSONParseError,
28+
Message,
29+
Part,
30+
PushNotificationConfig,
31+
Role,
32+
SendMessageResponse,
33+
SendMessageSuccessResponse,
34+
Task,
35+
TaskArtifactUpdateEvent,
36+
TaskPushNotificationConfig,
37+
TaskState,
38+
TaskStatus,
39+
TextPart,
40+
UnsupportedOperationError,
41+
)
1742
from a2a.utils.errors import MethodNotImplementedError
1843

1944
# === TEST SETUP ===
@@ -106,9 +131,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock):
106131

107132

108133
@pytest.fixture
109-
def client(app: A2AStarletteApplication):
134+
def client(app: A2AStarletteApplication, **kwargs):
110135
"""Create a test client with the app."""
111-
return TestClient(app.build())
136+
return TestClient(app.build(**kwargs))
112137

113138

114139
# === BASIC FUNCTIONALITY TESTS ===
@@ -135,7 +160,7 @@ def test_authenticated_extended_agent_card_endpoint_not_supported(
135160
# So, building the app and trying to hit it should result in 404 from Starlette itself
136161
client = TestClient(app_instance.build())
137162
response = client.get('/agent/authenticatedExtendedCard')
138-
assert response.status_code == 404 # Starlette's default for no route
163+
assert response.status_code == 404 # Starlette's default for no route
139164

140165

141166
def test_authenticated_extended_agent_card_endpoint_supported_with_specific_extended_card(
@@ -144,7 +169,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte
144169
handler: mock.AsyncMock,
145170
):
146171
"""Test extended card endpoint returns the specific extended card when provided."""
147-
agent_card.supportsAuthenticatedExtendedCard = True # Main card must support it
172+
agent_card.supportsAuthenticatedExtendedCard = (
173+
True # Main card must support it
174+
)
148175
app_instance = A2AStarletteApplication(
149176
agent_card, handler, extended_agent_card=extended_agent_card_fixture
150177
)
@@ -157,10 +184,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte
157184
assert data['name'] == extended_agent_card_fixture.name
158185
assert data['version'] == extended_agent_card_fixture.version
159186
assert len(data['skills']) == len(extended_agent_card_fixture.skills)
160-
assert any(
161-
skill['id'] == 'skill-extended' for skill in data['skills']
162-
), "Extended skill not found in served card"
163-
187+
assert any(skill['id'] == 'skill-extended' for skill in data['skills']), (
188+
'Extended skill not found in served card'
189+
)
164190

165191

166192
def test_agent_card_custom_url(
@@ -233,7 +259,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock):
233259
mock_task = Task(
234260
id='task1',
235261
contextId='session-xyz',
236-
state='completed',
237262
status=task_status,
238263
)
239264
handler.on_message_send.return_value = mock_task
@@ -402,6 +427,67 @@ def test_get_push_notification_config(
402427
handler.on_get_task_push_notification_config.assert_awaited_once()
403428

404429

430+
def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock):
431+
class TestAuthMiddleware(AuthenticationBackend):
432+
async def authenticate(
433+
self, conn: HTTPConnection
434+
) -> tuple[AuthCredentials, BaseUser] | None:
435+
# For the purposes of this test, all requests are authenticated!
436+
return (AuthCredentials(['authenticated']), SimpleUser('test_user'))
437+
438+
client = TestClient(
439+
app.build(
440+
middleware=[
441+
Middleware(
442+
AuthenticationMiddleware, backend=TestAuthMiddleware()
443+
)
444+
]
445+
)
446+
)
447+
448+
# Set the output message to be the authenticated user name
449+
handler.on_message_send.side_effect = lambda params, context: Message(
450+
contextId='session-xyz',
451+
messageId='112',
452+
role=Role.agent,
453+
parts=[
454+
Part(TextPart(text=context.user.user_name)),
455+
],
456+
)
457+
458+
# Send request
459+
response = client.post(
460+
'/',
461+
json={
462+
'jsonrpc': '2.0',
463+
'id': '123',
464+
'method': 'message/send',
465+
'params': {
466+
'message': {
467+
'role': 'agent',
468+
'parts': [{'kind': 'text', 'text': 'Hello'}],
469+
'messageId': '111',
470+
'kind': 'message',
471+
'taskId': 'task1',
472+
'contextId': 'session-xyz',
473+
}
474+
},
475+
},
476+
)
477+
478+
# Verify response
479+
assert response.status_code == 200
480+
result = SendMessageResponse.model_validate(response.json())
481+
assert isinstance(result.root, SendMessageSuccessResponse)
482+
assert isinstance(result.root.result, Message)
483+
message = result.root.result
484+
assert isinstance(message.parts[0].root, TextPart)
485+
assert message.parts[0].root.text == 'test_user'
486+
487+
# Verify handler was called
488+
handler.on_message_send.assert_awaited_once()
489+
490+
405491
# === STREAMING TESTS ===
406492

407493

0 commit comments

Comments
 (0)
Please sign in to comment.