Skip to content
15 changes: 11 additions & 4 deletions tests/client/test_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ClientFactory,
InMemoryContextCredentialStore,
)
from a2a.client.auth.credentials import InMemoryContextCredentialStore
from a2a.types import (
APIKeySecurityScheme,
AgentCapabilities,
Expand Down Expand Up @@ -106,7 +107,9 @@ def store():


@pytest.mark.asyncio
async def test_auth_interceptor_skips_when_no_agent_card(store):
async def test_auth_interceptor_skips_when_no_agent_card(
store: InMemoryContextCredentialStore,
):
"""
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
"""
Expand All @@ -126,7 +129,9 @@ async def test_auth_interceptor_skips_when_no_agent_card(store):


@pytest.mark.asyncio
async def test_in_memory_context_credential_store(store):
async def test_in_memory_context_credential_store(
store: InMemoryContextCredentialStore,
):
"""
Verifies that InMemoryContextCredentialStore correctly stores and retrieves
credentials based on the session ID in the client context.
Expand Down Expand Up @@ -284,7 +289,9 @@ class AuthTestCase:
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
)
@respx.mock
async def test_auth_interceptor_variants(test_case, store):
async def test_auth_interceptor_variants(
test_case: AuthTestCase, store: InMemoryContextCredentialStore
):
"""
Parametrized test verifying that AuthInterceptor correctly attaches credentials
based on the defined security scheme in the AgentCard.
Expand Down Expand Up @@ -329,7 +336,7 @@ async def test_auth_interceptor_variants(test_case, store):

@pytest.mark.asyncio
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
store,
store: InMemoryContextCredentialStore,
):
"""
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@


@pytest.fixture
def mock_transport():
def mock_transport() -> AsyncMock:
return AsyncMock(spec=ClientTransport)


@pytest.fixture
def sample_agent_card():
def sample_agent_card() -> AgentCard:
return AgentCard(
name='Test Agent',
description='An agent for testing',
Expand All @@ -38,7 +38,7 @@ def sample_agent_card():


@pytest.fixture
def sample_message():
def sample_message() -> Message:
return Message(
role=Role.user,
message_id='msg-1',
Expand All @@ -47,7 +47,7 @@ def sample_message():


@pytest.fixture
def base_client(sample_agent_card, mock_transport):
def base_client(sample_agent_card: AgentCard, mock_transport: AsyncMock):
config = ClientConfig(streaming=True)
return BaseClient(
card=sample_agent_card,
Expand Down
6 changes: 3 additions & 3 deletions tests/client/test_client_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@


@pytest.fixture
def task_manager():
def task_manager() -> ClientTaskManager:
return ClientTaskManager()


@pytest.fixture
def sample_task():
def sample_task() -> Task:
return Task(
id='task123',
context_id='context456',
Expand All @@ -38,7 +38,7 @@ def sample_task():


@pytest.fixture
def sample_message():
def sample_message() -> Message:
return Message(
message_id='msg1',
role=Role.user,
Expand Down
4 changes: 2 additions & 2 deletions tests/client/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_raising_base_error(self):
(500, 'Server Error', 'HTTP Error 500: Server Error'),
],
)
def test_http_error_parametrized(status_code, message, expected):
def test_http_error_parametrized(status_code: int, message: str, expected: str):
"""Parametrized test for HTTP errors with different status codes."""
error = A2AClientHTTPError(status_code, message)
assert error.status_code == status_code
Expand All @@ -194,7 +194,7 @@ def test_http_error_parametrized(status_code, message, expected):
('Parsing failed', 'JSON Error: Parsing failed'),
],
)
def test_json_error_parametrized(message, expected):
def test_json_error_parametrized(message: str, expected: str):
"""Parametrized test for JSON errors with different messages."""
error = A2AClientJSONError(message)
assert error.message == message
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sample_task_status_update_event() -> TaskStatusUpdateEvent:

@pytest.fixture
def sample_task_artifact_update_event(
sample_artifact,
sample_artifact: Artifact,
) -> TaskArtifactUpdateEvent:
"""Provides a sample TaskArtifactUpdateEvent."""
return TaskArtifactUpdateEvent(
Expand Down
73 changes: 39 additions & 34 deletions tests/server/agent_execution/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,22 @@ class TestRequestContext:
"""Tests for the RequestContext class."""

@pytest.fixture
def mock_message(self):
def mock_message(self) -> Mock:
"""Fixture for a mock Message."""

return Mock(spec=Message, task_id=None, context_id=None)

@pytest.fixture
def mock_params(self, mock_message):
def mock_params(self, mock_message: Mock) -> Mock:
"""Fixture for a mock MessageSendParams."""
return Mock(spec=MessageSendParams, message=mock_message)

@pytest.fixture
def mock_task(self):
def mock_task(self) -> Mock:
"""Fixture for a mock Task."""
return Mock(spec=Task, id='task-123', context_id='context-456')

def test_init_without_params(self):
def test_init_without_params(self) -> None:
"""Test initialization without parameters."""
context = RequestContext()
assert context.message is None
Expand All @@ -42,7 +43,7 @@ def test_init_without_params(self):
assert context.current_task is None
assert context.related_tasks == []

def test_init_with_params_no_ids(self, mock_params):
def test_init_with_params_no_ids(self, mock_params: Mock) -> None:
"""Test initialization with params but no task or context IDs."""
with patch(
'uuid.uuid4',
Expand All @@ -65,23 +66,23 @@ def test_init_with_params_no_ids(self, mock_params):
== '00000000-0000-0000-0000-000000000002'
)

def test_init_with_task_id(self, mock_params):
def test_init_with_task_id(self, mock_params: Mock) -> None:
"""Test initialization with task ID provided."""
task_id = 'task-123'
context = RequestContext(request=mock_params, task_id=task_id)

assert context.task_id == task_id
assert mock_params.message.task_id == task_id

def test_init_with_context_id(self, mock_params):
def test_init_with_context_id(self, mock_params: Mock) -> None:
"""Test initialization with context ID provided."""
context_id = 'context-456'
context = RequestContext(request=mock_params, context_id=context_id)

assert context.context_id == context_id
assert mock_params.message.context_id == context_id

def test_init_with_both_ids(self, mock_params):
def test_init_with_both_ids(self, mock_params: Mock) -> None:
"""Test initialization with both task and context IDs provided."""
task_id = 'task-123'
context_id = 'context-456'
Expand All @@ -94,7 +95,7 @@ def test_init_with_both_ids(self, mock_params):
assert context.context_id == context_id
assert mock_params.message.context_id == context_id

def test_init_with_task(self, mock_params, mock_task):
def test_init_with_task(self, mock_params: Mock, mock_task: Mock) -> None:
"""Test initialization with a task object."""
context = RequestContext(request=mock_params, task=mock_task)

Expand All @@ -105,7 +106,7 @@ def test_get_user_input_no_params(self):
context = RequestContext()
assert context.get_user_input() == ''

def test_attach_related_task(self, mock_task):
def test_attach_related_task(self, mock_task: Mock):
"""Test attach_related_task adds a task to related_tasks."""
context = RequestContext()
assert len(context.related_tasks) == 0
Expand All @@ -120,7 +121,7 @@ def test_attach_related_task(self, mock_task):
assert len(context.related_tasks) == 2
assert context.related_tasks[1] == another_task

def test_current_task_property(self, mock_task):
def test_current_task_property(self, mock_task: Mock) -> None:
"""Test current_task getter and setter."""
context = RequestContext()
assert context.current_task is None
Expand All @@ -133,13 +134,15 @@ def test_current_task_property(self, mock_task):
context.current_task = new_task
assert context.current_task == new_task

def test_check_or_generate_task_id_no_params(self):
def test_check_or_generate_task_id_no_params(self) -> None:
"""Test _check_or_generate_task_id with no params does nothing."""
context = RequestContext()
context._check_or_generate_task_id()
assert context.task_id is None

def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
def test_check_or_generate_task_id_with_existing_task_id(
self, mock_params: Mock
) -> None:
"""Test _check_or_generate_task_id with existing task ID."""
existing_id = 'existing-task-id'
mock_params.message.task_id = existing_id
Expand All @@ -151,8 +154,8 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
assert mock_params.message.task_id == existing_id

def test_check_or_generate_task_id_with_custom_id_generator(
self, mock_params
):
self, mock_params: Mock
) -> None:
"""Test _check_or_generate_task_id uses custom ID generator when provided."""
id_generator = Mock(spec=IDGenerator)
id_generator.generate.return_value = 'custom-task-id'
Expand All @@ -164,14 +167,14 @@ def test_check_or_generate_task_id_with_custom_id_generator(

assert context.task_id == 'custom-task-id'

def test_check_or_generate_context_id_no_params(self):
def test_check_or_generate_context_id_no_params(self) -> None:
"""Test _check_or_generate_context_id with no params does nothing."""
context = RequestContext()
context._check_or_generate_context_id()
assert context.context_id is None

def test_check_or_generate_context_id_with_existing_context_id(
self, mock_params
self, mock_params: Mock
):
"""Test _check_or_generate_context_id with existing context ID."""
existing_id = 'existing-context-id'
Expand All @@ -184,8 +187,8 @@ def test_check_or_generate_context_id_with_existing_context_id(
assert mock_params.message.context_id == existing_id

def test_check_or_generate_context_id_with_custom_id_generator(
self, mock_params
):
self, mock_params: Mock
) -> None:
"""Test _check_or_generate_context_id uses custom ID generator when provided."""
id_generator = Mock(spec=IDGenerator)
id_generator.generate.return_value = 'custom-context-id'
Expand All @@ -198,8 +201,8 @@ def test_check_or_generate_context_id_with_custom_id_generator(
assert context.context_id == 'custom-context-id'

def test_init_raises_error_on_task_id_mismatch(
self, mock_params, mock_task
):
self, mock_params: Mock, mock_task: Mock
) -> None:
"""Test that an error is raised if provided task_id mismatches task.id."""
with pytest.raises(ServerError) as exc_info:
RequestContext(
Expand All @@ -208,8 +211,8 @@ def test_init_raises_error_on_task_id_mismatch(
assert 'bad task id' in str(exc_info.value.error.message)

def test_init_raises_error_on_context_id_mismatch(
self, mock_params, mock_task
):
self, mock_params: Mock, mock_task: Mock
) -> None:
"""Test that an error is raised if provided context_id mismatches task.context_id."""
# Set a valid task_id to avoid that error
mock_params.message.task_id = mock_task.id
Expand All @@ -224,36 +227,38 @@ def test_init_raises_error_on_context_id_mismatch(

assert 'bad context id' in str(exc_info.value.error.message)

def test_with_related_tasks_provided(self, mock_task):
def test_with_related_tasks_provided(self, mock_task: Mock) -> None:
"""Test initialization with related tasks provided."""
related_tasks = [mock_task, Mock(spec=Task)]
context = RequestContext(related_tasks=related_tasks)

assert context.related_tasks == related_tasks
assert len(context.related_tasks) == 2

def test_message_property_without_params(self):
def test_message_property_without_params(self) -> None:
"""Test message property returns None when no params are provided."""
context = RequestContext()
assert context.message is None

def test_message_property_with_params(self, mock_params):
def test_message_property_with_params(self, mock_params: Mock) -> None:
"""Test message property returns the message from params."""
context = RequestContext(request=mock_params)
assert context.message == mock_params.message

def test_metadata_property_without_content(self):
def test_metadata_property_without_content(self) -> None:
"""Test metadata property returns empty dict when no content are provided."""
context = RequestContext()
assert context.metadata == {}

def test_metadata_property_with_content(self, mock_params):
def test_metadata_property_with_content(self, mock_params: Mock) -> None:
"""Test metadata property returns the metadata from params."""
mock_params.metadata = {'key': 'value'}
context = RequestContext(request=mock_params)
assert context.metadata == {'key': 'value'}

def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
def test_init_with_existing_ids_in_message(
self, mock_message, mock_params
) -> None:
"""Test initialization with existing IDs in the message."""
mock_message.task_id = 'existing-task-id'
mock_message.context_id = 'existing-context-id'
Expand All @@ -265,8 +270,8 @@ def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
# No new UUIDs should be generated

def test_init_with_task_id_and_existing_task_id_match(
self, mock_params, mock_task
):
self, mock_params: Mock, mock_task
) -> None:
"""Test initialization succeeds when task_id matches task.id."""
mock_params.message.task_id = mock_task.id

Expand All @@ -278,8 +283,8 @@ def test_init_with_task_id_and_existing_task_id_match(
assert context.current_task == mock_task

def test_init_with_context_id_and_existing_context_id_match(
self, mock_params, mock_task
):
self, mock_params: Mock, mock_task: Mock
) -> None:
"""Test initialization succeeds when context_id matches task.context_id."""
mock_params.message.task_id = mock_task.id # Set matching task ID
mock_params.message.context_id = mock_task.context_id
Expand All @@ -294,7 +299,7 @@ def test_init_with_context_id_and_existing_context_id_match(
assert context.context_id == mock_task.context_id
assert context.current_task == mock_task

def test_extension_handling(self):
def test_extension_handling(self) -> None:
"""Test extension handling in RequestContext."""
call_context = ServerCallContext(requested_extensions={'foo', 'bar'})
context = RequestContext(call_context=call_context)
Expand Down
Loading
Loading