Skip to content

Commit 5268218

Browse files
HelloJowetholtskinnergemini-code-assist[bot]
authored
style: type issues in test folder (#521)
# Description I noticed there are quite a few type warnings and errors in this repo when running pyright. I went ahead and fixed some of the type issues in the `/tests` folder to help clean things up a bit. Strong typing helps catch bugs early and makes the codebase easier to maintain in the long run. I’d be happy to contribute more and tackle additional type problems if there’s interest. This PR is just a first step to see if it makes sense to spend more time here and maybe get more involved as a contributor. Let me know what you think! --------- Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent aa159f3 commit 5268218

16 files changed

+411
-301
lines changed

tests/client/test_auth_middleware.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def store():
106106

107107

108108
@pytest.mark.asyncio
109-
async def test_auth_interceptor_skips_when_no_agent_card(store):
110-
"""
111-
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
112-
"""
109+
async def test_auth_interceptor_skips_when_no_agent_card(
110+
store: InMemoryContextCredentialStore,
111+
) -> None:
112+
"""Tests that the AuthInterceptor does not modify the request when no AgentCard is provided."""
113113
request_payload = {'foo': 'bar'}
114114
http_kwargs = {'fizz': 'buzz'}
115115
auth_interceptor = AuthInterceptor(credential_service=store)
@@ -126,9 +126,10 @@ async def test_auth_interceptor_skips_when_no_agent_card(store):
126126

127127

128128
@pytest.mark.asyncio
129-
async def test_in_memory_context_credential_store(store):
130-
"""
131-
Verifies that InMemoryContextCredentialStore correctly stores and retrieves
129+
async def test_in_memory_context_credential_store(
130+
store: InMemoryContextCredentialStore,
131+
) -> None:
132+
"""Verifies that InMemoryContextCredentialStore correctly stores and retrieves
132133
credentials based on the session ID in the client context.
133134
"""
134135
session_id = 'session-id'
@@ -163,11 +164,8 @@ async def test_in_memory_context_credential_store(store):
163164

164165
@pytest.mark.asyncio
165166
@respx.mock
166-
async def test_client_with_simple_interceptor():
167-
"""
168-
Ensures that a custom HeaderInterceptor correctly injects a static header
169-
into outbound HTTP requests from the A2AClient.
170-
"""
167+
async def test_client_with_simple_interceptor() -> None:
168+
"""Ensures that a custom HeaderInterceptor correctly injects a static header into outbound HTTP requests from the A2AClient."""
171169
url = 'http://agent.com/rpc'
172170
interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123')
173171
card = AgentCard(
@@ -196,9 +194,7 @@ async def test_client_with_simple_interceptor():
196194

197195
@dataclass
198196
class AuthTestCase:
199-
"""
200-
Represents a test scenario for verifying authentication behavior in AuthInterceptor.
201-
"""
197+
"""Represents a test scenario for verifying authentication behavior in AuthInterceptor."""
202198

203199
url: str
204200
"""The endpoint URL of the agent to which the request is sent."""
@@ -284,11 +280,10 @@ class AuthTestCase:
284280
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
285281
)
286282
@respx.mock
287-
async def test_auth_interceptor_variants(test_case, store):
288-
"""
289-
Parametrized test verifying that AuthInterceptor correctly attaches credentials
290-
based on the defined security scheme in the AgentCard.
291-
"""
283+
async def test_auth_interceptor_variants(
284+
test_case: AuthTestCase, store: InMemoryContextCredentialStore
285+
) -> None:
286+
"""Parametrized test verifying that AuthInterceptor correctly attaches credentials based on the defined security scheme in the AgentCard."""
292287
await store.set_credentials(
293288
test_case.session_id, test_case.scheme_name, test_case.credential
294289
)
@@ -329,12 +324,9 @@ async def test_auth_interceptor_variants(test_case, store):
329324

330325
@pytest.mark.asyncio
331326
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
332-
store,
333-
):
334-
"""
335-
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
336-
but not defined in security_schemes.
337-
"""
327+
store: InMemoryContextCredentialStore,
328+
) -> None:
329+
"""Tests that AuthInterceptor skips a scheme if it's listed in security requirements but not defined in security_schemes."""
338330
scheme_name = 'missing'
339331
session_id = 'session-id'
340332
credential = 'dummy-token'

tests/client/test_base_client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020

2121
@pytest.fixture
22-
def mock_transport():
22+
def mock_transport() -> AsyncMock:
2323
return AsyncMock(spec=ClientTransport)
2424

2525

2626
@pytest.fixture
27-
def sample_agent_card():
27+
def sample_agent_card() -> AgentCard:
2828
return AgentCard(
2929
name='Test Agent',
3030
description='An agent for testing',
@@ -38,7 +38,7 @@ def sample_agent_card():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
role=Role.user,
4444
message_id='msg-1',
@@ -47,7 +47,9 @@ def sample_message():
4747

4848

4949
@pytest.fixture
50-
def base_client(sample_agent_card, mock_transport):
50+
def base_client(
51+
sample_agent_card: AgentCard, mock_transport: AsyncMock
52+
) -> BaseClient:
5153
config = ClientConfig(streaming=True)
5254
return BaseClient(
5355
card=sample_agent_card,
@@ -61,7 +63,7 @@ def base_client(sample_agent_card, mock_transport):
6163
@pytest.mark.asyncio
6264
async def test_send_message_streaming(
6365
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
64-
):
66+
) -> None:
6567
async def create_stream(*args, **kwargs):
6668
yield Task(
6769
id='task-123',
@@ -82,7 +84,7 @@ async def create_stream(*args, **kwargs):
8284
@pytest.mark.asyncio
8385
async def test_send_message_non_streaming(
8486
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
85-
):
87+
) -> None:
8688
base_client._config.streaming = False
8789
mock_transport.send_message.return_value = Task(
8890
id='task-456',
@@ -101,7 +103,7 @@ async def test_send_message_non_streaming(
101103
@pytest.mark.asyncio
102104
async def test_send_message_non_streaming_agent_capability_false(
103105
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
104-
):
106+
) -> None:
105107
base_client._card.capabilities.streaming = False
106108
mock_transport.send_message.return_value = Task(
107109
id='task-789',

tests/client/test_client_task_manager.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323

2424
@pytest.fixture
25-
def task_manager():
25+
def task_manager() -> ClientTaskManager:
2626
return ClientTaskManager()
2727

2828

2929
@pytest.fixture
30-
def sample_task():
30+
def sample_task() -> Task:
3131
return Task(
3232
id='task123',
3333
context_id='context456',
@@ -38,29 +38,31 @@ def sample_task():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
message_id='msg1',
4444
role=Role.user,
4545
parts=[Part(root=TextPart(text='Hello'))],
4646
)
4747

4848

49-
def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager):
49+
def test_get_task_no_task_id_returns_none(
50+
task_manager: ClientTaskManager,
51+
) -> None:
5052
assert task_manager.get_task() is None
5153

5254

5355
def test_get_task_or_raise_no_task_raises_error(
5456
task_manager: ClientTaskManager,
55-
):
57+
) -> None:
5658
with pytest.raises(A2AClientInvalidStateError, match='no current Task'):
5759
task_manager.get_task_or_raise()
5860

5961

6062
@pytest.mark.asyncio
6163
async def test_save_task_event_with_task(
6264
task_manager: ClientTaskManager, sample_task: Task
63-
):
65+
) -> None:
6466
await task_manager.save_task_event(sample_task)
6567
assert task_manager.get_task() == sample_task
6668
assert task_manager._task_id == sample_task.id
@@ -70,7 +72,7 @@ async def test_save_task_event_with_task(
7072
@pytest.mark.asyncio
7173
async def test_save_task_event_with_task_already_set_raises_error(
7274
task_manager: ClientTaskManager, sample_task: Task
73-
):
75+
) -> None:
7476
await task_manager.save_task_event(sample_task)
7577
with pytest.raises(
7678
A2AClientInvalidArgsError,
@@ -82,7 +84,7 @@ async def test_save_task_event_with_task_already_set_raises_error(
8284
@pytest.mark.asyncio
8385
async def test_save_task_event_with_status_update(
8486
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
85-
):
87+
) -> None:
8688
await task_manager.save_task_event(sample_task)
8789
status_update = TaskStatusUpdateEvent(
8890
task_id=sample_task.id,
@@ -98,7 +100,7 @@ async def test_save_task_event_with_status_update(
98100
@pytest.mark.asyncio
99101
async def test_save_task_event_with_artifact_update(
100102
task_manager: ClientTaskManager, sample_task: Task
101-
):
103+
) -> None:
102104
await task_manager.save_task_event(sample_task)
103105
artifact = Artifact(
104106
artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))]
@@ -119,7 +121,7 @@ async def test_save_task_event_with_artifact_update(
119121
@pytest.mark.asyncio
120122
async def test_save_task_event_creates_task_if_not_exists(
121123
task_manager: ClientTaskManager,
122-
):
124+
) -> None:
123125
status_update = TaskStatusUpdateEvent(
124126
task_id='new_task',
125127
context_id='new_context',
@@ -135,7 +137,7 @@ async def test_save_task_event_creates_task_if_not_exists(
135137
@pytest.mark.asyncio
136138
async def test_process_with_task_event(
137139
task_manager: ClientTaskManager, sample_task: Task
138-
):
140+
) -> None:
139141
with patch.object(
140142
task_manager, 'save_task_event', new_callable=AsyncMock
141143
) as mock_save:
@@ -144,7 +146,9 @@ async def test_process_with_task_event(
144146

145147

146148
@pytest.mark.asyncio
147-
async def test_process_with_non_task_event(task_manager: ClientTaskManager):
149+
async def test_process_with_non_task_event(
150+
task_manager: ClientTaskManager,
151+
) -> None:
148152
with patch.object(
149153
task_manager, 'save_task_event', new_callable=Mock
150154
) as mock_save:
@@ -155,14 +159,14 @@ async def test_process_with_non_task_event(task_manager: ClientTaskManager):
155159

156160
def test_update_with_message(
157161
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
158-
):
162+
) -> None:
159163
updated_task = task_manager.update_with_message(sample_message, sample_task)
160164
assert updated_task.history == [sample_message]
161165

162166

163167
def test_update_with_message_moves_status_message(
164168
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
165-
):
169+
) -> None:
166170
status_message = Message(
167171
message_id='status_msg',
168172
role=Role.agent,

0 commit comments

Comments
 (0)