Skip to content

Commit 3ea7dde

Browse files
Merge branch 'main' into fix-export-base-client
2 parents f97ce17 + 5268218 commit 3ea7dde

16 files changed

+411
-301
lines changed

tests/client/test_auth_middleware.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def store():
106106

107107

108108
@pytest.mark.asyncio
109-
async def test_auth_interceptor_skips_when_no_agent_card(store):
110-
"""
111-
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
112-
"""
109+
async def test_auth_interceptor_skips_when_no_agent_card(
110+
store: InMemoryContextCredentialStore,
111+
) -> None:
112+
"""Tests that the AuthInterceptor does not modify the request when no AgentCard is provided."""
113113
request_payload = {'foo': 'bar'}
114114
http_kwargs = {'fizz': 'buzz'}
115115
auth_interceptor = AuthInterceptor(credential_service=store)
@@ -126,9 +126,10 @@ async def test_auth_interceptor_skips_when_no_agent_card(store):
126126

127127

128128
@pytest.mark.asyncio
129-
async def test_in_memory_context_credential_store(store):
130-
"""
131-
Verifies that InMemoryContextCredentialStore correctly stores and retrieves
129+
async def test_in_memory_context_credential_store(
130+
store: InMemoryContextCredentialStore,
131+
) -> None:
132+
"""Verifies that InMemoryContextCredentialStore correctly stores and retrieves
132133
credentials based on the session ID in the client context.
133134
"""
134135
session_id = 'session-id'
@@ -163,11 +164,8 @@ async def test_in_memory_context_credential_store(store):
163164

164165
@pytest.mark.asyncio
165166
@respx.mock
166-
async def test_client_with_simple_interceptor():
167-
"""
168-
Ensures that a custom HeaderInterceptor correctly injects a static header
169-
into outbound HTTP requests from the A2AClient.
170-
"""
167+
async def test_client_with_simple_interceptor() -> None:
168+
"""Ensures that a custom HeaderInterceptor correctly injects a static header into outbound HTTP requests from the A2AClient."""
171169
url = 'http://agent.com/rpc'
172170
interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123')
173171
card = AgentCard(
@@ -196,9 +194,7 @@ async def test_client_with_simple_interceptor():
196194

197195
@dataclass
198196
class AuthTestCase:
199-
"""
200-
Represents a test scenario for verifying authentication behavior in AuthInterceptor.
201-
"""
197+
"""Represents a test scenario for verifying authentication behavior in AuthInterceptor."""
202198

203199
url: str
204200
"""The endpoint URL of the agent to which the request is sent."""
@@ -284,11 +280,10 @@ class AuthTestCase:
284280
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
285281
)
286282
@respx.mock
287-
async def test_auth_interceptor_variants(test_case, store):
288-
"""
289-
Parametrized test verifying that AuthInterceptor correctly attaches credentials
290-
based on the defined security scheme in the AgentCard.
291-
"""
283+
async def test_auth_interceptor_variants(
284+
test_case: AuthTestCase, store: InMemoryContextCredentialStore
285+
) -> None:
286+
"""Parametrized test verifying that AuthInterceptor correctly attaches credentials based on the defined security scheme in the AgentCard."""
292287
await store.set_credentials(
293288
test_case.session_id, test_case.scheme_name, test_case.credential
294289
)
@@ -329,12 +324,9 @@ async def test_auth_interceptor_variants(test_case, store):
329324

330325
@pytest.mark.asyncio
331326
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
332-
store,
333-
):
334-
"""
335-
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
336-
but not defined in security_schemes.
337-
"""
327+
store: InMemoryContextCredentialStore,
328+
) -> None:
329+
"""Tests that AuthInterceptor skips a scheme if it's listed in security requirements but not defined in security_schemes."""
338330
scheme_name = 'missing'
339331
session_id = 'session-id'
340332
credential = 'dummy-token'

tests/client/test_base_client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020

2121
@pytest.fixture
22-
def mock_transport():
22+
def mock_transport() -> AsyncMock:
2323
return AsyncMock(spec=ClientTransport)
2424

2525

2626
@pytest.fixture
27-
def sample_agent_card():
27+
def sample_agent_card() -> AgentCard:
2828
return AgentCard(
2929
name='Test Agent',
3030
description='An agent for testing',
@@ -38,7 +38,7 @@ def sample_agent_card():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
role=Role.user,
4444
message_id='msg-1',
@@ -47,7 +47,9 @@ def sample_message():
4747

4848

4949
@pytest.fixture
50-
def base_client(sample_agent_card, mock_transport):
50+
def base_client(
51+
sample_agent_card: AgentCard, mock_transport: AsyncMock
52+
) -> BaseClient:
5153
config = ClientConfig(streaming=True)
5254
return BaseClient(
5355
card=sample_agent_card,
@@ -61,7 +63,7 @@ def base_client(sample_agent_card, mock_transport):
6163
@pytest.mark.asyncio
6264
async def test_send_message_streaming(
6365
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
64-
):
66+
) -> None:
6567
async def create_stream(*args, **kwargs):
6668
yield Task(
6769
id='task-123',
@@ -82,7 +84,7 @@ async def create_stream(*args, **kwargs):
8284
@pytest.mark.asyncio
8385
async def test_send_message_non_streaming(
8486
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
85-
):
87+
) -> None:
8688
base_client._config.streaming = False
8789
mock_transport.send_message.return_value = Task(
8890
id='task-456',
@@ -101,7 +103,7 @@ async def test_send_message_non_streaming(
101103
@pytest.mark.asyncio
102104
async def test_send_message_non_streaming_agent_capability_false(
103105
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
104-
):
106+
) -> None:
105107
base_client._card.capabilities.streaming = False
106108
mock_transport.send_message.return_value = Task(
107109
id='task-789',

tests/client/test_client_task_manager.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323

2424
@pytest.fixture
25-
def task_manager():
25+
def task_manager() -> ClientTaskManager:
2626
return ClientTaskManager()
2727

2828

2929
@pytest.fixture
30-
def sample_task():
30+
def sample_task() -> Task:
3131
return Task(
3232
id='task123',
3333
context_id='context456',
@@ -38,29 +38,31 @@ def sample_task():
3838

3939

4040
@pytest.fixture
41-
def sample_message():
41+
def sample_message() -> Message:
4242
return Message(
4343
message_id='msg1',
4444
role=Role.user,
4545
parts=[Part(root=TextPart(text='Hello'))],
4646
)
4747

4848

49-
def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager):
49+
def test_get_task_no_task_id_returns_none(
50+
task_manager: ClientTaskManager,
51+
) -> None:
5052
assert task_manager.get_task() is None
5153

5254

5355
def test_get_task_or_raise_no_task_raises_error(
5456
task_manager: ClientTaskManager,
55-
):
57+
) -> None:
5658
with pytest.raises(A2AClientInvalidStateError, match='no current Task'):
5759
task_manager.get_task_or_raise()
5860

5961

6062
@pytest.mark.asyncio
6163
async def test_save_task_event_with_task(
6264
task_manager: ClientTaskManager, sample_task: Task
63-
):
65+
) -> None:
6466
await task_manager.save_task_event(sample_task)
6567
assert task_manager.get_task() == sample_task
6668
assert task_manager._task_id == sample_task.id
@@ -70,7 +72,7 @@ async def test_save_task_event_with_task(
7072
@pytest.mark.asyncio
7173
async def test_save_task_event_with_task_already_set_raises_error(
7274
task_manager: ClientTaskManager, sample_task: Task
73-
):
75+
) -> None:
7476
await task_manager.save_task_event(sample_task)
7577
with pytest.raises(
7678
A2AClientInvalidArgsError,
@@ -82,7 +84,7 @@ async def test_save_task_event_with_task_already_set_raises_error(
8284
@pytest.mark.asyncio
8385
async def test_save_task_event_with_status_update(
8486
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
85-
):
87+
) -> None:
8688
await task_manager.save_task_event(sample_task)
8789
status_update = TaskStatusUpdateEvent(
8890
task_id=sample_task.id,
@@ -98,7 +100,7 @@ async def test_save_task_event_with_status_update(
98100
@pytest.mark.asyncio
99101
async def test_save_task_event_with_artifact_update(
100102
task_manager: ClientTaskManager, sample_task: Task
101-
):
103+
) -> None:
102104
await task_manager.save_task_event(sample_task)
103105
artifact = Artifact(
104106
artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))]
@@ -119,7 +121,7 @@ async def test_save_task_event_with_artifact_update(
119121
@pytest.mark.asyncio
120122
async def test_save_task_event_creates_task_if_not_exists(
121123
task_manager: ClientTaskManager,
122-
):
124+
) -> None:
123125
status_update = TaskStatusUpdateEvent(
124126
task_id='new_task',
125127
context_id='new_context',
@@ -135,7 +137,7 @@ async def test_save_task_event_creates_task_if_not_exists(
135137
@pytest.mark.asyncio
136138
async def test_process_with_task_event(
137139
task_manager: ClientTaskManager, sample_task: Task
138-
):
140+
) -> None:
139141
with patch.object(
140142
task_manager, 'save_task_event', new_callable=AsyncMock
141143
) as mock_save:
@@ -144,7 +146,9 @@ async def test_process_with_task_event(
144146

145147

146148
@pytest.mark.asyncio
147-
async def test_process_with_non_task_event(task_manager: ClientTaskManager):
149+
async def test_process_with_non_task_event(
150+
task_manager: ClientTaskManager,
151+
) -> None:
148152
with patch.object(
149153
task_manager, 'save_task_event', new_callable=Mock
150154
) as mock_save:
@@ -155,14 +159,14 @@ async def test_process_with_non_task_event(task_manager: ClientTaskManager):
155159

156160
def test_update_with_message(
157161
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
158-
):
162+
) -> None:
159163
updated_task = task_manager.update_with_message(sample_message, sample_task)
160164
assert updated_task.history == [sample_message]
161165

162166

163167
def test_update_with_message_moves_status_message(
164168
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
165-
):
169+
) -> None:
166170
status_message = Message(
167171
message_id='status_msg',
168172
role=Role.agent,

0 commit comments

Comments
 (0)