Skip to content

Commit 2db7f81

Browse files
authored
test: add coverage for error handlers, constants, optionals, and models (#392)
1 parent 1aa4dd7 commit 2db7f81

File tree

4 files changed

+247
-0
lines changed

4 files changed

+247
-0
lines changed

tests/client/test_optionals.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Tests for a2a.client.optionals module."""
2+
3+
import importlib
4+
import sys
5+
6+
from unittest.mock import patch
7+
8+
9+
def test_channel_import_failure():
10+
"""Test Channel behavior when grpc is not available."""
11+
with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}):
12+
if 'a2a.client.optionals' in sys.modules:
13+
del sys.modules['a2a.client.optionals']
14+
15+
optionals = importlib.import_module('a2a.client.optionals')
16+
assert optionals.Channel is None

tests/server/test_models.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Tests for a2a.server.models module."""
2+
3+
from unittest.mock import MagicMock
4+
5+
from sqlalchemy.orm import DeclarativeBase
6+
7+
from a2a.server.models import (
8+
PydanticListType,
9+
PydanticType,
10+
create_push_notification_config_model,
11+
create_task_model,
12+
)
13+
from a2a.types import Artifact, TaskState, TaskStatus, TextPart
14+
15+
16+
class TestPydanticType:
17+
"""Tests for PydanticType SQLAlchemy type decorator."""
18+
19+
def test_process_bind_param_with_pydantic_model(self):
20+
pydantic_type = PydanticType(TaskStatus)
21+
status = TaskStatus(state=TaskState.working)
22+
dialect = MagicMock()
23+
24+
result = pydantic_type.process_bind_param(status, dialect)
25+
assert result['state'] == 'working'
26+
assert result['message'] is None
27+
# TaskStatus may have other optional fields
28+
29+
def test_process_bind_param_with_none(self):
30+
pydantic_type = PydanticType(TaskStatus)
31+
dialect = MagicMock()
32+
33+
result = pydantic_type.process_bind_param(None, dialect)
34+
assert result is None
35+
36+
def test_process_result_value(self):
37+
pydantic_type = PydanticType(TaskStatus)
38+
dialect = MagicMock()
39+
40+
result = pydantic_type.process_result_value(
41+
{'state': 'completed', 'message': None}, dialect
42+
)
43+
assert isinstance(result, TaskStatus)
44+
assert result.state == 'completed'
45+
46+
47+
class TestPydanticListType:
48+
"""Tests for PydanticListType SQLAlchemy type decorator."""
49+
50+
def test_process_bind_param_with_list(self):
51+
pydantic_list_type = PydanticListType(Artifact)
52+
artifacts = [
53+
Artifact(
54+
artifact_id='1', parts=[TextPart(type='text', text='Hello')]
55+
),
56+
Artifact(
57+
artifact_id='2', parts=[TextPart(type='text', text='World')]
58+
),
59+
]
60+
dialect = MagicMock()
61+
62+
result = pydantic_list_type.process_bind_param(artifacts, dialect)
63+
assert len(result) == 2
64+
assert result[0]['artifactId'] == '1' # JSON mode uses camelCase
65+
assert result[1]['artifactId'] == '2'
66+
67+
def test_process_result_value_with_list(self):
68+
pydantic_list_type = PydanticListType(Artifact)
69+
dialect = MagicMock()
70+
data = [
71+
{'artifact_id': '1', 'parts': [{'type': 'text', 'text': 'Hello'}]},
72+
{'artifact_id': '2', 'parts': [{'type': 'text', 'text': 'World'}]},
73+
]
74+
75+
result = pydantic_list_type.process_result_value(data, dialect)
76+
assert len(result) == 2
77+
assert all(isinstance(art, Artifact) for art in result)
78+
assert result[0].artifact_id == '1'
79+
assert result[1].artifact_id == '2'
80+
81+
82+
def test_create_task_model():
83+
"""Test dynamic task model creation."""
84+
85+
# Create a fresh base to avoid table conflicts
86+
class TestBase(DeclarativeBase):
87+
pass
88+
89+
# Create with default table name
90+
default_task_model = create_task_model('test_tasks_1', TestBase)
91+
assert default_task_model.__tablename__ == 'test_tasks_1'
92+
assert default_task_model.__name__ == 'TaskModel_test_tasks_1'
93+
94+
# Create with custom table name
95+
custom_task_model = create_task_model('test_tasks_2', TestBase)
96+
assert custom_task_model.__tablename__ == 'test_tasks_2'
97+
assert custom_task_model.__name__ == 'TaskModel_test_tasks_2'
98+
99+
100+
def test_create_push_notification_config_model():
101+
"""Test dynamic push notification config model creation."""
102+
103+
# Create a fresh base to avoid table conflicts
104+
class TestBase(DeclarativeBase):
105+
pass
106+
107+
# Create with default table name
108+
default_model = create_push_notification_config_model(
109+
'test_push_configs_1', TestBase
110+
)
111+
assert default_model.__tablename__ == 'test_push_configs_1'
112+
113+
# Create with custom table name
114+
custom_model = create_push_notification_config_model(
115+
'test_push_configs_2', TestBase
116+
)
117+
assert custom_model.__tablename__ == 'test_push_configs_2'
118+
assert 'test_push_configs_2' in custom_model.__name__

tests/utils/test_constants.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Tests for a2a.utils.constants module."""
2+
3+
from a2a.utils import constants
4+
5+
6+
def test_agent_card_constants():
7+
"""Test that agent card constants have expected values."""
8+
assert (
9+
constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json'
10+
)
11+
assert (
12+
constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json'
13+
)
14+
assert (
15+
constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard'
16+
)
17+
18+
19+
def test_default_rpc_url():
20+
"""Test default RPC URL constant."""
21+
assert constants.DEFAULT_RPC_URL == '/'

tests/utils/test_error_handlers.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Tests for a2a.utils.error_handlers module."""
2+
3+
from unittest.mock import patch
4+
5+
import pytest
6+
7+
from a2a.types import (
8+
InternalError,
9+
InvalidRequestError,
10+
MethodNotFoundError,
11+
TaskNotFoundError,
12+
)
13+
from a2a.utils.error_handlers import (
14+
A2AErrorToHttpStatus,
15+
rest_error_handler,
16+
rest_stream_error_handler,
17+
)
18+
from a2a.utils.errors import ServerError
19+
20+
21+
class MockJSONResponse:
22+
def __init__(self, content, status_code):
23+
self.content = content
24+
self.status_code = status_code
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_rest_error_handler_server_error():
29+
"""Test rest_error_handler with ServerError."""
30+
error = InvalidRequestError(message='Bad request')
31+
32+
@rest_error_handler
33+
async def failing_func():
34+
raise ServerError(error=error)
35+
36+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
37+
result = await failing_func()
38+
39+
assert isinstance(result, MockJSONResponse)
40+
assert result.status_code == 400
41+
assert result.content == {'message': 'Bad request'}
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_rest_error_handler_unknown_exception():
46+
"""Test rest_error_handler with unknown exception."""
47+
48+
@rest_error_handler
49+
async def failing_func():
50+
raise ValueError('Unexpected error')
51+
52+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
53+
result = await failing_func()
54+
55+
assert isinstance(result, MockJSONResponse)
56+
assert result.status_code == 500
57+
assert result.content == {'message': 'unknown exception'}
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_rest_stream_error_handler_server_error():
62+
"""Test rest_stream_error_handler with ServerError."""
63+
error = InternalError(message='Internal server error')
64+
65+
@rest_stream_error_handler
66+
async def failing_stream():
67+
raise ServerError(error=error)
68+
69+
with pytest.raises(ServerError) as exc_info:
70+
await failing_stream()
71+
72+
assert exc_info.value.error == error
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_rest_stream_error_handler_reraises_exception():
77+
"""Test rest_stream_error_handler reraises other exceptions."""
78+
79+
@rest_stream_error_handler
80+
async def failing_stream():
81+
raise RuntimeError('Stream failed')
82+
83+
with pytest.raises(RuntimeError, match='Stream failed'):
84+
await failing_stream()
85+
86+
87+
def test_a2a_error_to_http_status_mapping():
88+
"""Test A2AErrorToHttpStatus mapping."""
89+
assert A2AErrorToHttpStatus[InvalidRequestError] == 400
90+
assert A2AErrorToHttpStatus[MethodNotFoundError] == 404
91+
assert A2AErrorToHttpStatus[TaskNotFoundError] == 404
92+
assert A2AErrorToHttpStatus[InternalError] == 500

0 commit comments

Comments
 (0)