diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py index 4f53ca3f..c41b4501 100644 --- a/tests/client/test_auth_middleware.py +++ b/tests/client/test_auth_middleware.py @@ -106,10 +106,10 @@ def store(): @pytest.mark.asyncio -async def test_auth_interceptor_skips_when_no_agent_card(store): - """ - Tests that the AuthInterceptor does not modify the request when no AgentCard is provided. - """ +async def test_auth_interceptor_skips_when_no_agent_card( + store: InMemoryContextCredentialStore, +) -> None: + """Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.""" request_payload = {'foo': 'bar'} http_kwargs = {'fizz': 'buzz'} auth_interceptor = AuthInterceptor(credential_service=store) @@ -126,9 +126,10 @@ async def test_auth_interceptor_skips_when_no_agent_card(store): @pytest.mark.asyncio -async def test_in_memory_context_credential_store(store): - """ - Verifies that InMemoryContextCredentialStore correctly stores and retrieves +async def test_in_memory_context_credential_store( + store: InMemoryContextCredentialStore, +) -> None: + """Verifies that InMemoryContextCredentialStore correctly stores and retrieves credentials based on the session ID in the client context. """ session_id = 'session-id' @@ -163,11 +164,8 @@ async def test_in_memory_context_credential_store(store): @pytest.mark.asyncio @respx.mock -async def test_client_with_simple_interceptor(): - """ - Ensures that a custom HeaderInterceptor correctly injects a static header - into outbound HTTP requests from the A2AClient. - """ +async def test_client_with_simple_interceptor() -> None: + """Ensures that a custom HeaderInterceptor correctly injects a static header into outbound HTTP requests from the A2AClient.""" url = 'http://agent.com/rpc' interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') card = AgentCard( @@ -196,9 +194,7 @@ async def test_client_with_simple_interceptor(): @dataclass class AuthTestCase: - """ - Represents a test scenario for verifying authentication behavior in AuthInterceptor. - """ + """Represents a test scenario for verifying authentication behavior in AuthInterceptor.""" url: str """The endpoint URL of the agent to which the request is sent.""" @@ -284,11 +280,10 @@ class AuthTestCase: [api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case], ) @respx.mock -async def test_auth_interceptor_variants(test_case, store): - """ - Parametrized test verifying that AuthInterceptor correctly attaches credentials - based on the defined security scheme in the AgentCard. - """ +async def test_auth_interceptor_variants( + test_case: AuthTestCase, store: InMemoryContextCredentialStore +) -> None: + """Parametrized test verifying that AuthInterceptor correctly attaches credentials based on the defined security scheme in the AgentCard.""" await store.set_credentials( test_case.session_id, test_case.scheme_name, test_case.credential ) @@ -329,12 +324,9 @@ async def test_auth_interceptor_variants(test_case, store): @pytest.mark.asyncio async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes( - store, -): - """ - Tests that AuthInterceptor skips a scheme if it's listed in security requirements - but not defined in security_schemes. - """ + store: InMemoryContextCredentialStore, +) -> None: + """Tests that AuthInterceptor skips a scheme if it's listed in security requirements but not defined in security_schemes.""" scheme_name = 'missing' session_id = 'session-id' credential = 'dummy-token' diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index c1251f1c..d93a2203 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -19,12 +19,12 @@ @pytest.fixture -def mock_transport(): +def mock_transport() -> AsyncMock: return AsyncMock(spec=ClientTransport) @pytest.fixture -def sample_agent_card(): +def sample_agent_card() -> AgentCard: return AgentCard( name='Test Agent', description='An agent for testing', @@ -38,7 +38,7 @@ def sample_agent_card(): @pytest.fixture -def sample_message(): +def sample_message() -> Message: return Message( role=Role.user, message_id='msg-1', @@ -47,7 +47,9 @@ def sample_message(): @pytest.fixture -def base_client(sample_agent_card, mock_transport): +def base_client( + sample_agent_card: AgentCard, mock_transport: AsyncMock +) -> BaseClient: config = ClientConfig(streaming=True) return BaseClient( card=sample_agent_card, @@ -61,7 +63,7 @@ def base_client(sample_agent_card, mock_transport): @pytest.mark.asyncio async def test_send_message_streaming( base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): +) -> None: async def create_stream(*args, **kwargs): yield Task( id='task-123', @@ -82,7 +84,7 @@ async def create_stream(*args, **kwargs): @pytest.mark.asyncio async def test_send_message_non_streaming( base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): +) -> None: base_client._config.streaming = False mock_transport.send_message.return_value = Task( id='task-456', @@ -101,7 +103,7 @@ async def test_send_message_non_streaming( @pytest.mark.asyncio async def test_send_message_non_streaming_agent_capability_false( base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): +) -> None: base_client._card.capabilities.streaming = False mock_transport.send_message.return_value = Task( id='task-789', diff --git a/tests/client/test_client_task_manager.py b/tests/client/test_client_task_manager.py index b07ddceb..63f98d8b 100644 --- a/tests/client/test_client_task_manager.py +++ b/tests/client/test_client_task_manager.py @@ -22,12 +22,12 @@ @pytest.fixture -def task_manager(): +def task_manager() -> ClientTaskManager: return ClientTaskManager() @pytest.fixture -def sample_task(): +def sample_task() -> Task: return Task( id='task123', context_id='context456', @@ -38,7 +38,7 @@ def sample_task(): @pytest.fixture -def sample_message(): +def sample_message() -> Message: return Message( message_id='msg1', role=Role.user, @@ -46,13 +46,15 @@ def sample_message(): ) -def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager): +def test_get_task_no_task_id_returns_none( + task_manager: ClientTaskManager, +) -> None: assert task_manager.get_task() is None def test_get_task_or_raise_no_task_raises_error( task_manager: ClientTaskManager, -): +) -> None: with pytest.raises(A2AClientInvalidStateError, match='no current Task'): task_manager.get_task_or_raise() @@ -60,7 +62,7 @@ def test_get_task_or_raise_no_task_raises_error( @pytest.mark.asyncio async def test_save_task_event_with_task( task_manager: ClientTaskManager, sample_task: Task -): +) -> None: await task_manager.save_task_event(sample_task) assert task_manager.get_task() == sample_task assert task_manager._task_id == sample_task.id @@ -70,7 +72,7 @@ async def test_save_task_event_with_task( @pytest.mark.asyncio async def test_save_task_event_with_task_already_set_raises_error( task_manager: ClientTaskManager, sample_task: Task -): +) -> None: await task_manager.save_task_event(sample_task) with pytest.raises( A2AClientInvalidArgsError, @@ -82,7 +84,7 @@ async def test_save_task_event_with_task_already_set_raises_error( @pytest.mark.asyncio async def test_save_task_event_with_status_update( task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -): +) -> None: await task_manager.save_task_event(sample_task) status_update = TaskStatusUpdateEvent( task_id=sample_task.id, @@ -98,7 +100,7 @@ async def test_save_task_event_with_status_update( @pytest.mark.asyncio async def test_save_task_event_with_artifact_update( task_manager: ClientTaskManager, sample_task: Task -): +) -> None: await task_manager.save_task_event(sample_task) artifact = Artifact( artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))] @@ -119,7 +121,7 @@ async def test_save_task_event_with_artifact_update( @pytest.mark.asyncio async def test_save_task_event_creates_task_if_not_exists( task_manager: ClientTaskManager, -): +) -> None: status_update = TaskStatusUpdateEvent( task_id='new_task', context_id='new_context', @@ -135,7 +137,7 @@ async def test_save_task_event_creates_task_if_not_exists( @pytest.mark.asyncio async def test_process_with_task_event( task_manager: ClientTaskManager, sample_task: Task -): +) -> None: with patch.object( task_manager, 'save_task_event', new_callable=AsyncMock ) as mock_save: @@ -144,7 +146,9 @@ async def test_process_with_task_event( @pytest.mark.asyncio -async def test_process_with_non_task_event(task_manager: ClientTaskManager): +async def test_process_with_non_task_event( + task_manager: ClientTaskManager, +) -> None: with patch.object( task_manager, 'save_task_event', new_callable=Mock ) as mock_save: @@ -155,14 +159,14 @@ async def test_process_with_non_task_event(task_manager: ClientTaskManager): def test_update_with_message( task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -): +) -> None: updated_task = task_manager.update_with_message(sample_message, sample_task) assert updated_task.history == [sample_message] def test_update_with_message_moves_status_message( task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -): +) -> None: status_message = Message( message_id='status_msg', role=Role.agent, diff --git a/tests/client/test_errors.py b/tests/client/test_errors.py index 30c4468d..60636bd3 100644 --- a/tests/client/test_errors.py +++ b/tests/client/test_errors.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import pytest from a2a.client import A2AClientError, A2AClientHTTPError, A2AClientJSONError @@ -6,13 +8,13 @@ class TestA2AClientError: """Test cases for the base A2AClientError class.""" - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that A2AClientError can be instantiated.""" error = A2AClientError('Test error message') assert isinstance(error, Exception) assert str(error) == 'Test error message' - def test_inheritance(self): + def test_inheritance(self) -> None: """Test that A2AClientError inherits from Exception.""" error = A2AClientError() assert isinstance(error, Exception) @@ -21,31 +23,31 @@ def test_inheritance(self): class TestA2AClientHTTPError: """Test cases for A2AClientHTTPError class.""" - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that A2AClientHTTPError can be instantiated with status_code and message.""" error = A2AClientHTTPError(404, 'Not Found') assert isinstance(error, A2AClientError) assert error.status_code == 404 assert error.message == 'Not Found' - def test_message_formatting(self): + def test_message_formatting(self) -> None: """Test that the error message is formatted correctly.""" error = A2AClientHTTPError(500, 'Internal Server Error') assert str(error) == 'HTTP Error 500: Internal Server Error' - def test_inheritance(self): + def test_inheritance(self) -> None: """Test that A2AClientHTTPError inherits from A2AClientError.""" error = A2AClientHTTPError(400, 'Bad Request') assert isinstance(error, A2AClientError) - def test_with_empty_message(self): + def test_with_empty_message(self) -> None: """Test behavior with an empty message.""" error = A2AClientHTTPError(403, '') assert error.status_code == 403 assert error.message == '' assert str(error) == 'HTTP Error 403: ' - def test_with_various_status_codes(self): + def test_with_various_status_codes(self) -> None: """Test with different HTTP status codes.""" test_cases = [ (200, 'OK'), @@ -68,29 +70,29 @@ def test_with_various_status_codes(self): class TestA2AClientJSONError: """Test cases for A2AClientJSONError class.""" - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that A2AClientJSONError can be instantiated with a message.""" error = A2AClientJSONError('Invalid JSON format') assert isinstance(error, A2AClientError) assert error.message == 'Invalid JSON format' - def test_message_formatting(self): + def test_message_formatting(self) -> None: """Test that the error message is formatted correctly.""" error = A2AClientJSONError('Missing required field') assert str(error) == 'JSON Error: Missing required field' - def test_inheritance(self): + def test_inheritance(self) -> None: """Test that A2AClientJSONError inherits from A2AClientError.""" error = A2AClientJSONError('Parsing error') assert isinstance(error, A2AClientError) - def test_with_empty_message(self): + def test_with_empty_message(self) -> None: """Test behavior with an empty message.""" error = A2AClientJSONError('') assert error.message == '' assert str(error) == 'JSON Error: ' - def test_with_various_messages(self): + def test_with_various_messages(self) -> None: """Test with different error messages.""" test_messages = [ 'Malformed JSON', @@ -109,13 +111,13 @@ def test_with_various_messages(self): class TestExceptionHierarchy: """Test the exception hierarchy and relationships.""" - def test_exception_hierarchy(self): + def test_exception_hierarchy(self) -> None: """Test that the exception hierarchy is correct.""" assert issubclass(A2AClientError, Exception) assert issubclass(A2AClientHTTPError, A2AClientError) assert issubclass(A2AClientJSONError, A2AClientError) - def test_catch_specific_exception(self): + def test_catch_specific_exception(self) -> None: """Test that specific exceptions can be caught.""" try: raise A2AClientHTTPError(404, 'Not Found') @@ -123,7 +125,7 @@ def test_catch_specific_exception(self): assert e.status_code == 404 assert e.message == 'Not Found' - def test_catch_base_exception(self): + def test_catch_base_exception(self) -> None: """Test that derived exceptions can be caught as base exception.""" exceptions = [ A2AClientHTTPError(404, 'Not Found'), @@ -140,7 +142,7 @@ def test_catch_base_exception(self): class TestExceptionRaising: """Test cases for raising and handling the exceptions.""" - def test_raising_http_error(self): + def test_raising_http_error(self) -> NoReturn: """Test raising an HTTP error and checking its properties.""" with pytest.raises(A2AClientHTTPError) as excinfo: raise A2AClientHTTPError(429, 'Too Many Requests') @@ -150,7 +152,7 @@ def test_raising_http_error(self): assert error.message == 'Too Many Requests' assert str(error) == 'HTTP Error 429: Too Many Requests' - def test_raising_json_error(self): + def test_raising_json_error(self) -> NoReturn: """Test raising a JSON error and checking its properties.""" with pytest.raises(A2AClientJSONError) as excinfo: raise A2AClientJSONError('Invalid format') @@ -159,7 +161,7 @@ def test_raising_json_error(self): assert error.message == 'Invalid format' assert str(error) == 'JSON Error: Invalid format' - def test_raising_base_error(self): + def test_raising_base_error(self) -> NoReturn: """Test raising the base error.""" with pytest.raises(A2AClientError) as excinfo: raise A2AClientError('Generic client error') @@ -178,7 +180,9 @@ def test_raising_base_error(self): (500, 'Server Error', 'HTTP Error 500: Server Error'), ], ) -def test_http_error_parametrized(status_code, message, expected): +def test_http_error_parametrized( + status_code: int, message: str, expected: str +) -> None: """Parametrized test for HTTP errors with different status codes.""" error = A2AClientHTTPError(status_code, message) assert error.status_code == status_code @@ -194,7 +198,7 @@ def test_http_error_parametrized(status_code, message, expected): ('Parsing failed', 'JSON Error: Parsing failed'), ], ) -def test_json_error_parametrized(message, expected): +def test_json_error_parametrized(message: str, expected: str) -> None: """Parametrized test for JSON errors with different messages.""" error = A2AClientJSONError(message) assert error.message == message diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index 19f5abc1..6dab75e9 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -128,7 +128,7 @@ def sample_task_status_update_event() -> TaskStatusUpdateEvent: @pytest.fixture def sample_task_artifact_update_event( - sample_artifact, + sample_artifact: Artifact, ) -> TaskArtifactUpdateEvent: """Provides a sample TaskArtifactUpdateEvent.""" return TaskArtifactUpdateEvent( @@ -179,7 +179,7 @@ async def test_send_message_task_response( mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_task: Task, -): +) -> None: """Test send_message that returns a Task.""" mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( task=proto_utils.ToProto.task(sample_task) @@ -198,7 +198,7 @@ async def test_send_message_message_response( mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_message: Message, -): +) -> None: """Test send_message that returns a Message.""" mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( msg=proto_utils.ToProto.message(sample_message) @@ -223,7 +223,7 @@ async def test_send_message_streaming( # noqa: PLR0913 sample_task: Task, sample_task_status_update_event: TaskStatusUpdateEvent, sample_task_artifact_update_event: TaskArtifactUpdateEvent, -): +) -> None: """Test send_message_streaming that yields responses.""" stream = MagicMock() stream.read = AsyncMock( @@ -268,7 +268,7 @@ async def test_send_message_streaming( # noqa: PLR0913 @pytest.mark.asyncio async def test_get_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test retrieving a task.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) params = TaskQueryParams(id=sample_task.id) @@ -286,7 +286,7 @@ async def test_get_task( @pytest.mark.asyncio async def test_get_task_with_history( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test retrieving a task with history.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) history_len = 10 @@ -304,7 +304,7 @@ async def test_get_task_with_history( @pytest.mark.asyncio async def test_cancel_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test cancelling a task.""" cancelled_task = sample_task.model_copy() cancelled_task.status.state = TaskState.canceled @@ -326,7 +326,7 @@ async def test_set_task_callback_with_valid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test setting a task push notification config with a valid task id.""" mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( proto_utils.ToProto.task_push_notification_config( @@ -355,7 +355,7 @@ async def test_set_task_callback_with_invalid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test setting a task push notification config with an invalid task id.""" mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( name=( @@ -382,7 +382,7 @@ async def test_get_task_callback_with_valid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test retrieving a task push notification config with a valid task id.""" mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( proto_utils.ToProto.task_push_notification_config( @@ -412,7 +412,7 @@ async def test_get_task_callback_with_invalid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test retrieving a task push notification config with an invalid task id.""" mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( name=( diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 684aecb2..979978ad 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -19,21 +19,21 @@ class TestRequestContext: """Tests for the RequestContext class.""" @pytest.fixture - def mock_message(self): + def mock_message(self) -> Mock: """Fixture for a mock Message.""" return Mock(spec=Message, task_id=None, context_id=None) @pytest.fixture - def mock_params(self, mock_message): + def mock_params(self, mock_message: Mock) -> Mock: """Fixture for a mock MessageSendParams.""" return Mock(spec=MessageSendParams, message=mock_message) @pytest.fixture - def mock_task(self): + def mock_task(self) -> Mock: """Fixture for a mock Task.""" return Mock(spec=Task, id='task-123', context_id='context-456') - def test_init_without_params(self): + def test_init_without_params(self) -> None: """Test initialization without parameters.""" context = RequestContext() assert context.message is None @@ -42,7 +42,7 @@ def test_init_without_params(self): assert context.current_task is None assert context.related_tasks == [] - def test_init_with_params_no_ids(self, mock_params): + def test_init_with_params_no_ids(self, mock_params: Mock) -> None: """Test initialization with params but no task or context IDs.""" with patch( 'uuid.uuid4', @@ -65,7 +65,7 @@ def test_init_with_params_no_ids(self, mock_params): == '00000000-0000-0000-0000-000000000002' ) - def test_init_with_task_id(self, mock_params): + def test_init_with_task_id(self, mock_params: Mock) -> None: """Test initialization with task ID provided.""" task_id = 'task-123' context = RequestContext(request=mock_params, task_id=task_id) @@ -73,7 +73,7 @@ def test_init_with_task_id(self, mock_params): assert context.task_id == task_id assert mock_params.message.task_id == task_id - def test_init_with_context_id(self, mock_params): + def test_init_with_context_id(self, mock_params: Mock) -> None: """Test initialization with context ID provided.""" context_id = 'context-456' context = RequestContext(request=mock_params, context_id=context_id) @@ -81,7 +81,7 @@ def test_init_with_context_id(self, mock_params): assert context.context_id == context_id assert mock_params.message.context_id == context_id - def test_init_with_both_ids(self, mock_params): + def test_init_with_both_ids(self, mock_params: Mock) -> None: """Test initialization with both task and context IDs provided.""" task_id = 'task-123' context_id = 'context-456' @@ -94,18 +94,18 @@ def test_init_with_both_ids(self, mock_params): assert context.context_id == context_id assert mock_params.message.context_id == context_id - def test_init_with_task(self, mock_params, mock_task): + def test_init_with_task(self, mock_params: Mock, mock_task: Mock) -> None: """Test initialization with a task object.""" context = RequestContext(request=mock_params, task=mock_task) assert context.current_task == mock_task - def test_get_user_input_no_params(self): + def test_get_user_input_no_params(self) -> None: """Test get_user_input with no params returns empty string.""" context = RequestContext() assert context.get_user_input() == '' - def test_attach_related_task(self, mock_task): + def test_attach_related_task(self, mock_task: Mock) -> None: """Test attach_related_task adds a task to related_tasks.""" context = RequestContext() assert len(context.related_tasks) == 0 @@ -120,7 +120,7 @@ def test_attach_related_task(self, mock_task): assert len(context.related_tasks) == 2 assert context.related_tasks[1] == another_task - def test_current_task_property(self, mock_task): + def test_current_task_property(self, mock_task: Mock) -> None: """Test current_task getter and setter.""" context = RequestContext() assert context.current_task is None @@ -133,13 +133,15 @@ def test_current_task_property(self, mock_task): context.current_task = new_task assert context.current_task == new_task - def test_check_or_generate_task_id_no_params(self): + def test_check_or_generate_task_id_no_params(self) -> None: """Test _check_or_generate_task_id with no params does nothing.""" context = RequestContext() context._check_or_generate_task_id() assert context.task_id is None - def test_check_or_generate_task_id_with_existing_task_id(self, mock_params): + def test_check_or_generate_task_id_with_existing_task_id( + self, mock_params: Mock + ) -> None: """Test _check_or_generate_task_id with existing task ID.""" existing_id = 'existing-task-id' mock_params.message.task_id = existing_id @@ -151,8 +153,8 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params): assert mock_params.message.task_id == existing_id def test_check_or_generate_task_id_with_custom_id_generator( - self, mock_params - ): + self, mock_params: Mock + ) -> None: """Test _check_or_generate_task_id uses custom ID generator when provided.""" id_generator = Mock(spec=IDGenerator) id_generator.generate.return_value = 'custom-task-id' @@ -164,15 +166,15 @@ def test_check_or_generate_task_id_with_custom_id_generator( assert context.task_id == 'custom-task-id' - def test_check_or_generate_context_id_no_params(self): + def test_check_or_generate_context_id_no_params(self) -> None: """Test _check_or_generate_context_id with no params does nothing.""" context = RequestContext() context._check_or_generate_context_id() assert context.context_id is None def test_check_or_generate_context_id_with_existing_context_id( - self, mock_params - ): + self, mock_params: Mock + ) -> None: """Test _check_or_generate_context_id with existing context ID.""" existing_id = 'existing-context-id' mock_params.message.context_id = existing_id @@ -184,8 +186,8 @@ def test_check_or_generate_context_id_with_existing_context_id( assert mock_params.message.context_id == existing_id def test_check_or_generate_context_id_with_custom_id_generator( - self, mock_params - ): + self, mock_params: Mock + ) -> None: """Test _check_or_generate_context_id uses custom ID generator when provided.""" id_generator = Mock(spec=IDGenerator) id_generator.generate.return_value = 'custom-context-id' @@ -198,8 +200,8 @@ def test_check_or_generate_context_id_with_custom_id_generator( assert context.context_id == 'custom-context-id' def test_init_raises_error_on_task_id_mismatch( - self, mock_params, mock_task - ): + self, mock_params: Mock, mock_task: Mock + ) -> None: """Test that an error is raised if provided task_id mismatches task.id.""" with pytest.raises(ServerError) as exc_info: RequestContext( @@ -208,8 +210,8 @@ def test_init_raises_error_on_task_id_mismatch( assert 'bad task id' in str(exc_info.value.error.message) def test_init_raises_error_on_context_id_mismatch( - self, mock_params, mock_task - ): + self, mock_params: Mock, mock_task: Mock + ) -> None: """Test that an error is raised if provided context_id mismatches task.context_id.""" # Set a valid task_id to avoid that error mock_params.message.task_id = mock_task.id @@ -224,7 +226,7 @@ def test_init_raises_error_on_context_id_mismatch( assert 'bad context id' in str(exc_info.value.error.message) - def test_with_related_tasks_provided(self, mock_task): + def test_with_related_tasks_provided(self, mock_task: Mock) -> None: """Test initialization with related tasks provided.""" related_tasks = [mock_task, Mock(spec=Task)] context = RequestContext(related_tasks=related_tasks) @@ -232,28 +234,30 @@ def test_with_related_tasks_provided(self, mock_task): assert context.related_tasks == related_tasks assert len(context.related_tasks) == 2 - def test_message_property_without_params(self): + def test_message_property_without_params(self) -> None: """Test message property returns None when no params are provided.""" context = RequestContext() assert context.message is None - def test_message_property_with_params(self, mock_params): + def test_message_property_with_params(self, mock_params: Mock) -> None: """Test message property returns the message from params.""" context = RequestContext(request=mock_params) assert context.message == mock_params.message - def test_metadata_property_without_content(self): + def test_metadata_property_without_content(self) -> None: """Test metadata property returns empty dict when no content are provided.""" context = RequestContext() assert context.metadata == {} - def test_metadata_property_with_content(self, mock_params): + def test_metadata_property_with_content(self, mock_params: Mock) -> None: """Test metadata property returns the metadata from params.""" mock_params.metadata = {'key': 'value'} context = RequestContext(request=mock_params) assert context.metadata == {'key': 'value'} - def test_init_with_existing_ids_in_message(self, mock_message, mock_params): + def test_init_with_existing_ids_in_message( + self, mock_message: Mock, mock_params: Mock + ) -> None: """Test initialization with existing IDs in the message.""" mock_message.task_id = 'existing-task-id' mock_message.context_id = 'existing-context-id' @@ -265,8 +269,8 @@ def test_init_with_existing_ids_in_message(self, mock_message, mock_params): # No new UUIDs should be generated def test_init_with_task_id_and_existing_task_id_match( - self, mock_params, mock_task - ): + self, mock_params: Mock, mock_task: Mock + ) -> None: """Test initialization succeeds when task_id matches task.id.""" mock_params.message.task_id = mock_task.id @@ -278,8 +282,8 @@ def test_init_with_task_id_and_existing_task_id_match( assert context.current_task == mock_task def test_init_with_context_id_and_existing_context_id_match( - self, mock_params, mock_task - ): + self, mock_params: Mock, mock_task: Mock + ) -> None: """Test initialization succeeds when context_id matches task.context_id.""" mock_params.message.task_id = mock_task.id # Set matching task ID mock_params.message.context_id = mock_task.context_id @@ -294,7 +298,7 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.context_id == mock_task.context_id assert context.current_task == mock_task - def test_extension_handling(self): + def test_extension_handling(self) -> None: """Test extension handling in RequestContext.""" call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) context = RequestContext(call_context=call_context) diff --git a/tests/server/agent_execution/test_simple_request_context_builder.py b/tests/server/agent_execution/test_simple_request_context_builder.py index 116f4011..5e1b8fd8 100644 --- a/tests/server/agent_execution/test_simple_request_context_builder.py +++ b/tests/server/agent_execution/test_simple_request_context_builder.py @@ -26,23 +26,25 @@ # Helper to create a simple message def create_sample_message( - content='test message', - msg_id='msg1', - role=Role.user, - reference_task_ids=None, -): + content: str = 'test message', + msg_id: str = 'msg1', + role: Role = Role.user, + reference_task_ids: list[str] | None = None, +) -> Message: return Message( message_id=msg_id, role=role, parts=[Part(root=TextPart(text=content))], - referenceTaskIds=reference_task_ids if reference_task_ids else [], + reference_task_ids=reference_task_ids if reference_task_ids else [], ) # Helper to create a simple task def create_sample_task( - task_id='task1', status_state=TaskState.submitted, context_id='ctx1' -): + task_id: str = 'task1', + status_state: TaskState = TaskState.submitted, + context_id: str = 'ctx1', +) -> Task: return Task( id=task_id, context_id=context_id, @@ -51,24 +53,24 @@ def create_sample_task( class TestSimpleRequestContextBuilder(unittest.IsolatedAsyncioTestCase): - def setUp(self): + def setUp(self) -> None: self.mock_task_store = AsyncMock(spec=TaskStore) - def test_init_with_populate_true_and_task_store(self): + def test_init_with_populate_true_and_task_store(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=True, task_store=self.mock_task_store ) self.assertTrue(builder._should_populate_referred_tasks) self.assertEqual(builder._task_store, self.mock_task_store) - def test_init_with_populate_false_task_store_none(self): + def test_init_with_populate_false_task_store_none(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=False, task_store=None ) self.assertFalse(builder._should_populate_referred_tasks) self.assertIsNone(builder._task_store) - def test_init_with_populate_false_task_store_provided(self): + def test_init_with_populate_false_task_store_provided(self) -> None: # Even if populate is false, task_store might still be provided (though not used by build for related_tasks) builder = SimpleRequestContextBuilder( should_populate_referred_tasks=False, @@ -77,7 +79,7 @@ def test_init_with_populate_false_task_store_provided(self): self.assertFalse(builder._should_populate_referred_tasks) self.assertEqual(builder._task_store, self.mock_task_store) - async def test_build_basic_context_no_populate(self): + async def test_build_basic_context_no_populate(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=False, task_store=self.mock_task_store, @@ -117,7 +119,7 @@ async def test_build_basic_context_no_populate(self): self.assertEqual(request_context.related_tasks, []) # Initialized to [] self.mock_task_store.get.assert_not_called() - async def test_build_populate_true_with_reference_task_ids(self): + async def test_build_populate_true_with_reference_task_ids(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=True, task_store=self.mock_task_store ) @@ -167,7 +169,7 @@ async def get_side_effect(task_id): self.assertIn(mock_ref_task1, request_context.related_tasks) self.assertIn(mock_ref_task3, request_context.related_tasks) - async def test_build_populate_true_params_none(self): + async def test_build_populate_true_params_none(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=True, task_store=self.mock_task_store ) @@ -182,7 +184,9 @@ async def test_build_populate_true_params_none(self): self.assertEqual(request_context.related_tasks, []) self.mock_task_store.get.assert_not_called() - async def test_build_populate_true_reference_ids_empty_or_none(self): + async def test_build_populate_true_reference_ids_empty_or_none( + self, + ) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=True, task_store=self.mock_task_store ) @@ -224,7 +228,7 @@ async def test_build_populate_true_reference_ids_empty_or_none(self): self.assertEqual(request_context_none.related_tasks, []) self.mock_task_store.get.assert_not_called() - async def test_build_populate_true_task_store_none(self): + async def test_build_populate_true_task_store_none(self) -> None: # This scenario might be prevented by constructor logic if should_populate_referred_tasks is True, # but testing defensively. The builder might allow task_store=None if it's set post-init, # or if constructor logic changes. Current SimpleRequestContextBuilder takes it at init. @@ -249,7 +253,7 @@ async def test_build_populate_true_task_store_none(self): self.assertEqual(request_context.related_tasks, []) # No mock_task_store to check calls on, this test is mostly for graceful handling. - async def test_build_populate_false_with_reference_task_ids(self): + async def test_build_populate_false_with_reference_task_ids(self) -> None: builder = SimpleRequestContextBuilder( should_populate_referred_tasks=False, task_store=self.mock_task_store, diff --git a/tests/server/events/test_event_queue.py b/tests/server/events/test_event_queue.py index 18ebf72b..0ff966cc 100644 --- a/tests/server/events/test_event_queue.py +++ b/tests/server/events/test_event_queue.py @@ -45,20 +45,20 @@ def event_queue() -> EventQueue: return EventQueue() -def test_constructor_default_max_queue_size(): +def test_constructor_default_max_queue_size() -> None: """Test that the queue is created with the default max size.""" eq = EventQueue() assert eq.queue.maxsize == DEFAULT_MAX_QUEUE_SIZE -def test_constructor_max_queue_size(): +def test_constructor_max_queue_size() -> None: """Test that the asyncio.Queue is created with the specified max_queue_size.""" custom_size = 123 eq = EventQueue(max_queue_size=custom_size) assert eq.queue.maxsize == custom_size -def test_constructor_invalid_max_queue_size(): +def test_constructor_invalid_max_queue_size() -> None: """Test that a ValueError is raised for non-positive max_queue_size.""" with pytest.raises( ValueError, match='max_queue_size must be greater than 0' @@ -170,7 +170,7 @@ async def test_enqueue_event_propagates_to_children( @pytest.mark.asyncio async def test_enqueue_event_when_closed( - event_queue: EventQueue, expected_queue_closed_exception + event_queue: EventQueue, expected_queue_closed_exception: type[Exception] ) -> None: """Test that no event is enqueued if the parent queue is closed.""" await event_queue.close() # Close the queue first @@ -199,7 +199,7 @@ async def test_enqueue_event_when_closed( @pytest.fixture -def expected_queue_closed_exception(): +def expected_queue_closed_exception() -> type[Exception]: if sys.version_info < (3, 13): return asyncio.QueueEmpty return asyncio.QueueShutDown @@ -207,7 +207,7 @@ def expected_queue_closed_exception(): @pytest.mark.asyncio async def test_dequeue_event_closed_and_empty_no_wait( - event_queue: EventQueue, expected_queue_closed_exception + event_queue: EventQueue, expected_queue_closed_exception: type[Exception] ) -> None: """Test dequeue_event raises QueueEmpty when closed, empty, and no_wait=True.""" await event_queue.close() @@ -222,7 +222,7 @@ async def test_dequeue_event_closed_and_empty_no_wait( @pytest.mark.asyncio async def test_dequeue_event_closed_and_empty_waits_then_raises( - event_queue: EventQueue, expected_queue_closed_exception + event_queue: EventQueue, expected_queue_closed_exception: type[Exception] ) -> None: """Test dequeue_event raises QueueEmpty eventually when closed, empty, and no_wait=False.""" await event_queue.close() @@ -409,7 +409,6 @@ async def test_close_immediate_propagates_to_children( event_queue: EventQueue, ) -> None: """Test that immediate parameter is propagated to child queues.""" - child_queue = event_queue.tap() # Add events to both parent and child @@ -430,7 +429,6 @@ async def test_close_immediate_propagates_to_children( @pytest.mark.asyncio async def test_clear_events_current_queue_only(event_queue: EventQueue) -> None: """Test clear_events clears only the current queue when clear_child_queues=False.""" - child_queue = event_queue.tap() event1 = Message(**MESSAGE_PAYLOAD) event2 = Task(**MINIMAL_TASK) @@ -454,7 +452,6 @@ async def test_clear_events_current_queue_only(event_queue: EventQueue) -> None: @pytest.mark.asyncio async def test_clear_events_with_children(event_queue: EventQueue) -> None: """Test clear_events clears both current queue and child queues.""" - # Create child queues and add events child_queue1 = event_queue.tap() child_queue2 = event_queue.tap() diff --git a/tests/server/events/test_inmemory_queue_manager.py b/tests/server/events/test_inmemory_queue_manager.py index 3fb8f4c7..b51334a9 100644 --- a/tests/server/events/test_inmemory_queue_manager.py +++ b/tests/server/events/test_inmemory_queue_manager.py @@ -14,33 +14,38 @@ class TestInMemoryQueueManager: @pytest.fixture - def queue_manager(self): + def queue_manager(self) -> InMemoryQueueManager: """Fixture to create a fresh InMemoryQueueManager for each test.""" return InMemoryQueueManager() @pytest.fixture - def event_queue(self): + def event_queue(self) -> MagicMock: """Fixture to create a mock EventQueue.""" queue = MagicMock(spec=EventQueue) + # Mock the tap method to return itself queue.tap.return_value = queue return queue @pytest.mark.asyncio - async def test_init(self, queue_manager): + async def test_init(self, queue_manager: InMemoryQueueManager) -> None: """Test that the InMemoryQueueManager initializes with empty task queue and a lock.""" assert queue_manager._task_queue == {} assert isinstance(queue_manager._lock, asyncio.Lock) @pytest.mark.asyncio - async def test_add_new_queue(self, queue_manager, event_queue): + async def test_add_new_queue( + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test adding a new queue to the manager.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) assert queue_manager._task_queue[task_id] == event_queue @pytest.mark.asyncio - async def test_add_existing_queue(self, queue_manager, event_queue): + async def test_add_existing_queue( + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test adding a queue with an existing task_id raises TaskQueueExists.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) @@ -49,7 +54,9 @@ async def test_add_existing_queue(self, queue_manager, event_queue): await queue_manager.add(task_id, event_queue) @pytest.mark.asyncio - async def test_get_existing_queue(self, queue_manager, event_queue): + async def test_get_existing_queue( + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test getting an existing queue returns the queue.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) @@ -58,13 +65,17 @@ async def test_get_existing_queue(self, queue_manager, event_queue): assert result == event_queue @pytest.mark.asyncio - async def test_get_nonexistent_queue(self, queue_manager): + async def test_get_nonexistent_queue( + self, queue_manager: InMemoryQueueManager + ) -> None: """Test getting a nonexistent queue returns None.""" result = await queue_manager.get('nonexistent_task_id') assert result is None @pytest.mark.asyncio - async def test_tap_existing_queue(self, queue_manager, event_queue): + async def test_tap_existing_queue( + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test tapping an existing queue returns the tapped queue.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) @@ -74,13 +85,17 @@ async def test_tap_existing_queue(self, queue_manager, event_queue): event_queue.tap.assert_called_once() @pytest.mark.asyncio - async def test_tap_nonexistent_queue(self, queue_manager): + async def test_tap_nonexistent_queue( + self, queue_manager: InMemoryQueueManager + ) -> None: """Test tapping a nonexistent queue returns None.""" result = await queue_manager.tap('nonexistent_task_id') assert result is None @pytest.mark.asyncio - async def test_close_existing_queue(self, queue_manager, event_queue): + async def test_close_existing_queue( + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test closing an existing queue removes it from the manager.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) @@ -89,13 +104,17 @@ async def test_close_existing_queue(self, queue_manager, event_queue): assert task_id not in queue_manager._task_queue @pytest.mark.asyncio - async def test_close_nonexistent_queue(self, queue_manager): + async def test_close_nonexistent_queue( + self, queue_manager: InMemoryQueueManager + ) -> None: """Test closing a nonexistent queue raises NoTaskQueue.""" with pytest.raises(NoTaskQueue): await queue_manager.close('nonexistent_task_id') @pytest.mark.asyncio - async def test_create_or_tap_new_queue(self, queue_manager): + async def test_create_or_tap_new_queue( + self, queue_manager: InMemoryQueueManager + ) -> None: """Test create_or_tap with a new task_id creates and returns a new queue.""" task_id = 'test_task_id' @@ -105,8 +124,8 @@ async def test_create_or_tap_new_queue(self, queue_manager): @pytest.mark.asyncio async def test_create_or_tap_existing_queue( - self, queue_manager, event_queue - ): + self, queue_manager: InMemoryQueueManager, event_queue: MagicMock + ) -> None: """Test create_or_tap with an existing task_id taps and returns the existing queue.""" task_id = 'test_task_id' await queue_manager.add(task_id, event_queue) @@ -117,7 +136,9 @@ async def test_create_or_tap_existing_queue( event_queue.tap.assert_called_once() @pytest.mark.asyncio - async def test_concurrency(self, queue_manager): + async def test_concurrency( + self, queue_manager: InMemoryQueueManager + ) -> None: """Test concurrent access to the queue manager.""" async def add_task(task_id): diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 05af6cda..26f923c1 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -61,7 +61,7 @@ async def test_send_message_success( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test successful SendMessage call.""" request_proto = a2a_pb2.SendMessageRequest( request=a2a_pb2.Message(message_id='msg-1') @@ -86,7 +86,7 @@ async def test_send_message_server_error( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test SendMessage call when handler raises a ServerError.""" request_proto = a2a_pb2.SendMessageRequest() error = ServerError(error=types.InvalidParamsError(message='Bad params')) @@ -104,7 +104,7 @@ async def test_get_task_success( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test successful GetTask call.""" request_proto = a2a_pb2.GetTaskRequest(name='tasks/task-1') response_model = types.Task( @@ -126,7 +126,7 @@ async def test_get_task_not_found( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test GetTask call when task is not found.""" request_proto = a2a_pb2.GetTaskRequest(name='tasks/task-1') mock_request_handler.on_get_task.return_value = None @@ -143,7 +143,7 @@ async def test_cancel_task_server_error( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test CancelTask call when handler raises ServerError.""" request_proto = a2a_pb2.CancelTaskRequest(name='tasks/task-1') error = ServerError(error=types.TaskNotCancelableError()) @@ -162,7 +162,7 @@ async def test_send_streaming_message( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, -): +) -> None: """Test successful SendStreamingMessage call.""" async def mock_stream(): @@ -192,7 +192,7 @@ async def test_get_agent_card( grpc_handler: GrpcHandler, sample_agent_card: types.AgentCard, mock_grpc_context: AsyncMock, -): +) -> None: """Test GetAgentCard call.""" request_proto = a2a_pb2.GetAgentCardRequest() response = await grpc_handler.GetAgentCard(request_proto, mock_grpc_context) @@ -206,7 +206,7 @@ async def test_get_agent_card_with_modifier( mock_request_handler: AsyncMock, sample_agent_card: types.AgentCard, mock_grpc_context: AsyncMock, -): +) -> None: """Test GetAgentCard call with a card_modifier.""" def modifier(card: types.AgentCard) -> types.AgentCard: @@ -299,10 +299,10 @@ async def test_abort_context_error_mapping( # noqa: PLR0913 grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, - server_error, - grpc_status_code, - error_message_part, -): + server_error: ServerError, + grpc_status_code: grpc.StatusCode, + error_message_part: str, +) -> None: mock_request_handler.on_get_task.side_effect = server_error request_proto = a2a_pb2.GetTaskRequest(name='tasks/any') await grpc_handler.GetTask(request_proto, mock_grpc_context) @@ -320,7 +320,7 @@ async def test_send_message_with_extensions( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, - ): + ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'bar'), @@ -360,7 +360,7 @@ async def test_send_message_with_comma_separated_extensions( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, - ): + ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo ,, bar,'), (HTTP_EXTENSION_HEADER, 'baz , bar'), @@ -385,7 +385,7 @@ async def test_send_streaming_message_with_extensions( grpc_handler: GrpcHandler, mock_request_handler: AsyncMock, mock_grpc_context: AsyncMock, - ): + ) -> None: mock_grpc_context.invocation_metadata = grpc.aio.Metadata( (HTTP_EXTENSION_HEADER, 'foo'), (HTTP_EXTENSION_HEADER, 'bar'), diff --git a/tests/server/request_handlers/test_response_helpers.py b/tests/server/request_handlers/test_response_helpers.py index 96e79e51..36de78e6 100644 --- a/tests/server/request_handlers/test_response_helpers.py +++ b/tests/server/request_handlers/test_response_helpers.py @@ -22,7 +22,7 @@ class TestResponseHelpers(unittest.TestCase): - def test_build_error_response_with_a2a_error(self): + def test_build_error_response_with_a2a_error(self) -> None: request_id = 'req1' specific_error = TaskNotFoundError() a2a_error = A2AError(root=specific_error) # Correctly wrap @@ -36,7 +36,7 @@ def test_build_error_response_with_a2a_error(self): response_wrapper.root.error, specific_error ) # build_error_response unwraps A2AError - def test_build_error_response_with_jsonrpc_error(self): + def test_build_error_response_with_jsonrpc_error(self) -> None: request_id = 123 json_rpc_error = InvalidParamsError( message='Custom invalid params' @@ -51,7 +51,7 @@ def test_build_error_response_with_jsonrpc_error(self): response_wrapper.root.error, json_rpc_error ) # No .root access for json_rpc_error - def test_build_error_response_with_a2a_wrapping_jsonrpc_error(self): + def test_build_error_response_with_a2a_wrapping_jsonrpc_error(self) -> None: request_id = 'req_wrap' specific_jsonrpc_error = InvalidParamsError(message='Detail error') a2a_error_wrapping = A2AError( @@ -65,7 +65,7 @@ def test_build_error_response_with_a2a_wrapping_jsonrpc_error(self): self.assertEqual(response_wrapper.root.id, request_id) self.assertEqual(response_wrapper.root.error, specific_jsonrpc_error) - def test_build_error_response_with_request_id_string(self): + def test_build_error_response_with_request_id_string(self) -> None: request_id = 'string_id_test' # Pass an A2AError-wrapped specific error for consistency with how build_error_response handles A2AError error = A2AError(root=TaskNotFoundError()) @@ -75,7 +75,7 @@ def test_build_error_response_with_request_id_string(self): self.assertIsInstance(response_wrapper.root, JSONRPCErrorResponse) self.assertEqual(response_wrapper.root.id, request_id) - def test_build_error_response_with_request_id_int(self): + def test_build_error_response_with_request_id_int(self) -> None: request_id = 456 error = A2AError(root=TaskNotFoundError()) response_wrapper = build_error_response( @@ -84,7 +84,7 @@ def test_build_error_response_with_request_id_int(self): self.assertIsInstance(response_wrapper.root, JSONRPCErrorResponse) self.assertEqual(response_wrapper.root.id, request_id) - def test_build_error_response_with_request_id_none(self): + def test_build_error_response_with_request_id_none(self) -> None: request_id = None error = A2AError(root=TaskNotFoundError()) response_wrapper = build_error_response( @@ -93,7 +93,9 @@ def test_build_error_response_with_request_id_none(self): self.assertIsInstance(response_wrapper.root, JSONRPCErrorResponse) self.assertIsNone(response_wrapper.root.id) - def _create_sample_task(self, task_id='task123', context_id='ctx456'): + def _create_sample_task( + self, task_id: str = 'task123', context_id: str = 'ctx456' + ) -> Task: return Task( id=task_id, context_id=context_id, @@ -101,7 +103,7 @@ def _create_sample_task(self, task_id='task123', context_id='ctx456'): history=[], ) - def test_prepare_response_object_successful_response(self): + def test_prepare_response_object_successful_response(self) -> None: request_id = 'req_success' task_result = self._create_sample_task() response_wrapper = prepare_response_object( @@ -119,7 +121,7 @@ def test_prepare_response_object_successful_response(self): @patch('a2a.server.request_handlers.response_helpers.build_error_response') def test_prepare_response_object_with_a2a_error_instance( self, mock_build_error - ): + ) -> None: request_id = 'req_a2a_err' specific_error = TaskNotFoundError() a2a_error_instance = A2AError( @@ -150,7 +152,7 @@ def test_prepare_response_object_with_a2a_error_instance( @patch('a2a.server.request_handlers.response_helpers.build_error_response') def test_prepare_response_object_with_jsonrpcerror_base_instance( self, mock_build_error - ): + ) -> None: request_id = 789 # Use the base JSONRPCError class instance json_rpc_base_error = JSONRPCError( @@ -180,7 +182,7 @@ def test_prepare_response_object_with_jsonrpcerror_base_instance( @patch('a2a.server.request_handlers.response_helpers.build_error_response') def test_prepare_response_object_specific_error_model_as_unexpected( self, mock_build_error - ): + ) -> None: request_id = 'req_specific_unexpected' # Pass a specific error model (like TaskNotFoundError) directly, NOT wrapped in A2AError # This should be treated as an "unexpected" type by prepare_response_object's current logic @@ -219,7 +221,7 @@ def test_prepare_response_object_specific_error_model_as_unexpected( self.assertEqual(args[2], GetTaskResponse) self.assertEqual(response_wrapper, mock_final_wrapped_response) - def test_prepare_response_object_with_request_id_string(self): + def test_prepare_response_object_with_request_id_string(self) -> None: request_id = 'string_id_prep' task_result = self._create_sample_task() response_wrapper = prepare_response_object( @@ -232,7 +234,7 @@ def test_prepare_response_object_with_request_id_string(self): self.assertIsInstance(response_wrapper.root, GetTaskSuccessResponse) self.assertEqual(response_wrapper.root.id, request_id) - def test_prepare_response_object_with_request_id_int(self): + def test_prepare_response_object_with_request_id_int(self) -> None: request_id = 101112 task_result = self._create_sample_task() response_wrapper = prepare_response_object( @@ -245,7 +247,7 @@ def test_prepare_response_object_with_request_id_int(self): self.assertIsInstance(response_wrapper.root, GetTaskSuccessResponse) self.assertEqual(response_wrapper.root.id, request_id) - def test_prepare_response_object_with_request_id_none(self): + def test_prepare_response_object_with_request_id_none(self) -> None: request_id = None task_result = self._create_sample_task() response_wrapper = prepare_response_object( diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index 93baf0d3..375ed97c 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -17,7 +17,9 @@ # logging.disable(logging.CRITICAL) -def create_sample_task(task_id='task123', status_state=TaskState.completed): +def create_sample_task( + task_id: str = 'task123', status_state: TaskState = TaskState.completed +) -> Task: return Task( id=task_id, context_id='ctx456', @@ -26,23 +28,25 @@ def create_sample_task(task_id='task123', status_state=TaskState.completed): def create_sample_push_config( - url='http://example.com/callback', config_id='cfg1', token=None -): + url: str = 'http://example.com/callback', + config_id: str = 'cfg1', + token: str | None = None, +) -> PushNotificationConfig: return PushNotificationConfig(id=config_id, url=url, token=token) class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): - def setUp(self): + def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) self.config_store = InMemoryPushNotificationConfigStore() self.notifier = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.config_store ) # Corrected argument name - def test_constructor_stores_client(self): + def test_constructor_stores_client(self) -> None: self.assertEqual(self.notifier._client, self.mock_httpx_client) - async def test_set_info_adds_new_config(self): + async def test_set_info_adds_new_config(self) -> None: task_id = 'task_new' config = create_sample_push_config(url='http://new.url/callback') @@ -53,7 +57,7 @@ async def test_set_info_adds_new_config(self): self.config_store._push_notification_infos[task_id], [config] ) - async def test_set_info_appends_to_existing_config(self): + async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' initial_config = create_sample_push_config( url='http://initial.url/callback', config_id='cfg_initial' @@ -75,7 +79,7 @@ async def test_set_info_appends_to_existing_config(self): updated_config, ) - async def test_set_info_without_config_id(self): + async def test_set_info_without_config_id(self) -> None: task_id = 'task1' initial_config = PushNotificationConfig( url='http://initial.url/callback' @@ -98,7 +102,7 @@ async def test_set_info_without_config_id(self): updated_config.url, ) - async def test_get_info_existing_config(self): + async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' config = create_sample_push_config(url='http://get.this/callback') await self.config_store.set_info(task_id, config) @@ -106,12 +110,12 @@ async def test_get_info_existing_config(self): retrieved_config = await self.config_store.get_info(task_id) self.assertEqual(retrieved_config, [config]) - async def test_get_info_non_existent_config(self): + async def test_get_info_non_existent_config(self) -> None: task_id = 'task_get_non_exist' retrieved_config = await self.config_store.get_info(task_id) assert retrieved_config == [] - async def test_delete_info_existing_config(self): + async def test_delete_info_existing_config(self) -> None: task_id = 'task_delete_exist' config = create_sample_push_config(url='http://delete.this/callback') await self.config_store.set_info(task_id, config) @@ -120,7 +124,7 @@ async def test_delete_info_existing_config(self): await self.config_store.delete_info(task_id, config_id=config.id) self.assertNotIn(task_id, self.config_store._push_notification_infos) - async def test_delete_info_non_existent_config(self): + async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' # Ensure it doesn't raise an error try: @@ -133,7 +137,7 @@ async def test_delete_info_non_existent_config(self): task_id, self.config_store._push_notification_infos ) # Should still not be there - async def test_send_notification_success(self): + async def test_send_notification_success(self) -> None: task_id = 'task_send_success' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config(url='http://notify.me/here') @@ -158,7 +162,7 @@ async def test_send_notification_success(self): ) # auth is not passed by current implementation mock_response.raise_for_status.assert_called_once() - async def test_send_notification_with_token_success(self): + async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config( @@ -189,7 +193,7 @@ async def test_send_notification_with_token_success(self): ) # auth is not passed by current implementation mock_response.raise_for_status.assert_called_once() - async def test_send_notification_no_config(self): + async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = create_sample_task(task_id=task_id) @@ -200,7 +204,7 @@ async def test_send_notification_no_config(self): @patch('a2a.server.tasks.base_push_notification_sender.logger') async def test_send_notification_http_status_error( self, mock_logger: MagicMock - ): + ) -> None: task_id = 'task_send_http_err' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config(url='http://notify.me/http_error') @@ -230,7 +234,7 @@ async def test_send_notification_http_status_error( @patch('a2a.server.tasks.base_push_notification_sender.logger') async def test_send_notification_request_error( self, mock_logger: MagicMock - ): + ) -> None: task_id = 'task_send_req_err' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config(url='http://notify.me/req_error') @@ -249,7 +253,9 @@ async def test_send_notification_request_error( ) @patch('a2a.server.tasks.base_push_notification_sender.logger') - async def test_send_notification_with_auth(self, mock_logger: MagicMock): + async def test_send_notification_with_auth( + self, mock_logger: MagicMock + ) -> None: task_id = 'task_send_auth' task_data = create_sample_task(task_id=task_id) auth_info = ('user', 'pass') diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index fb398670..a3272c2c 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -15,7 +15,9 @@ ) -def create_sample_task(task_id='task123', status_state=TaskState.completed): +def create_sample_task( + task_id: str = 'task123', status_state: TaskState = TaskState.completed +) -> Task: return Task( id=task_id, context_id='ctx456', @@ -24,13 +26,15 @@ def create_sample_task(task_id='task123', status_state=TaskState.completed): def create_sample_push_config( - url='http://example.com/callback', config_id='cfg1', token=None -): + url: str = 'http://example.com/callback', + config_id: str = 'cfg1', + token: str | None = None, +) -> PushNotificationConfig: return PushNotificationConfig(id=config_id, url=url, token=token) class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase): - def setUp(self): + def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) self.mock_config_store = AsyncMock() self.sender = BasePushNotificationSender( @@ -38,11 +42,11 @@ def setUp(self): config_store=self.mock_config_store, ) - def test_constructor_stores_client_and_config_store(self): + def test_constructor_stores_client_and_config_store(self) -> None: self.assertEqual(self.sender._client, self.mock_httpx_client) self.assertEqual(self.sender._config_store, self.mock_config_store) - async def test_send_notification_success(self): + async def test_send_notification_success(self) -> None: task_id = 'task_send_success' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config(url='http://notify.me/here') @@ -64,7 +68,7 @@ async def test_send_notification_success(self): ) mock_response.raise_for_status.assert_called_once() - async def test_send_notification_with_token_success(self): + async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config( @@ -88,7 +92,7 @@ async def test_send_notification_with_token_success(self): ) mock_response.raise_for_status.assert_called_once() - async def test_send_notification_no_config(self): + async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = create_sample_task(task_id=task_id) self.mock_config_store.get_info.return_value = [] @@ -101,7 +105,7 @@ async def test_send_notification_no_config(self): @patch('a2a.server.tasks.base_push_notification_sender.logger') async def test_send_notification_http_status_error( self, mock_logger: MagicMock - ): + ) -> None: task_id = 'task_send_http_err' task_data = create_sample_task(task_id=task_id) config = create_sample_push_config(url='http://notify.me/http_error') @@ -125,7 +129,7 @@ async def test_send_notification_http_status_error( ) mock_logger.exception.assert_called_once() - async def test_send_notification_multiple_configs(self): + async def test_send_notification_multiple_configs(self) -> None: task_id = 'task_multiple_configs' task_data = create_sample_task(task_id=task_id) config1 = create_sample_push_config( diff --git a/tests/server/tasks/test_result_aggregator.py b/tests/server/tasks/test_result_aggregator.py index da77e693..bc970246 100644 --- a/tests/server/tasks/test_result_aggregator.py +++ b/tests/server/tasks/test_result_aggregator.py @@ -4,6 +4,8 @@ from collections.abc import AsyncIterator from unittest.mock import AsyncMock, MagicMock, patch +from typing_extensions import override + from a2a.server.events.event_consumer import EventConsumer from a2a.server.tasks.result_aggregator import ResultAggregator from a2a.server.tasks.task_manager import TaskManager @@ -21,8 +23,8 @@ # Helper to create a simple message def create_sample_message( - content='test message', msg_id='msg1', role=Role.user -): + content: str = 'test message', msg_id: str = 'msg1', role: Role = Role.user +) -> Message: return Message( message_id=msg_id, role=role, @@ -32,8 +34,10 @@ def create_sample_message( # Helper to create a simple task def create_sample_task( - task_id='task1', status_state=TaskState.submitted, context_id='ctx1' -): + task_id: str = 'task1', + status_state: TaskState = TaskState.submitted, + context_id: str = 'ctx1', +) -> Task: return Task( id=task_id, context_id=context_id, @@ -43,8 +47,10 @@ def create_sample_task( # Helper to create a TaskStatusUpdateEvent def create_sample_status_update( - task_id='task1', status_state=TaskState.working, context_id='ctx1' -): + task_id: str = 'task1', + status_state: TaskState = TaskState.working, + context_id: str = 'ctx1', +) -> TaskStatusUpdateEvent: return TaskStatusUpdateEvent( task_id=task_id, context_id=context_id, @@ -54,7 +60,8 @@ def create_sample_status_update( class TestResultAggregator(unittest.IsolatedAsyncioTestCase): - def setUp(self): + @override + def setUp(self) -> None: self.mock_task_manager = AsyncMock(spec=TaskManager) self.mock_event_consumer = AsyncMock(spec=EventConsumer) self.aggregator = ResultAggregator( @@ -62,17 +69,17 @@ def setUp(self): # event_consumer is not passed to constructor ) - def test_init_stores_task_manager(self): + def test_init_stores_task_manager(self) -> None: self.assertEqual(self.aggregator.task_manager, self.mock_task_manager) # event_consumer is also stored, can be tested if needed, but focus is on task_manager per req. - async def test_current_result_property_with_message_set(self): + async def test_current_result_property_with_message_set(self) -> None: sample_message = create_sample_message(content='hola') self.aggregator._message = sample_message self.assertEqual(await self.aggregator.current_result, sample_message) self.mock_task_manager.get_task.assert_not_called() - async def test_current_result_property_with_message_none(self): + async def test_current_result_property_with_message_none(self) -> None: expected_task = create_sample_task(task_id='task_from_tm') self.mock_task_manager.get_task.return_value = expected_task self.aggregator._message = None @@ -82,7 +89,7 @@ async def test_current_result_property_with_message_none(self): self.assertEqual(current_res, expected_task) self.mock_task_manager.get_task.assert_called_once() - async def test_consume_and_emit(self): + async def test_consume_and_emit(self) -> None: event1 = create_sample_message(content='event one', msg_id='e1') event2 = create_sample_task( task_id='task_event', status_state=TaskState.working @@ -120,7 +127,7 @@ async def mock_consume_generator(): self.mock_task_manager.process.assert_any_call(event2) self.mock_task_manager.process.assert_any_call(event3) - async def test_consume_all_only_message_event(self): + async def test_consume_all_only_message_event(self) -> None: sample_message = create_sample_message(content='final message') async def mock_consume_generator(): @@ -136,7 +143,7 @@ async def mock_consume_generator(): self.mock_task_manager.process.assert_not_called() # Process is not called if message is returned directly self.mock_task_manager.get_task.assert_not_called() # Should not be called if message is returned - async def test_consume_all_other_event_types(self): + async def test_consume_all_other_event_types(self) -> None: task_event = create_sample_task(task_id='task_other_event') status_update_event = create_sample_status_update( task_id='task_other_event', status_state=TaskState.completed @@ -162,7 +169,7 @@ async def mock_consume_generator(): self.mock_task_manager.process.assert_any_call(status_update_event) self.mock_task_manager.get_task.assert_called_once() - async def test_consume_all_empty_stream(self): + async def test_consume_all_empty_stream(self) -> None: empty_task_state = create_sample_task(task_id='empty_stream_task') async def mock_consume_generator(): @@ -180,7 +187,7 @@ async def mock_consume_generator(): self.mock_task_manager.process.assert_not_called() self.mock_task_manager.get_task.assert_called_once() - async def test_consume_all_event_consumer_exception(self): + async def test_consume_all_event_consumer_exception(self) -> None: class TestException(Exception): pass @@ -206,7 +213,7 @@ async def raiser_gen(): ) self.mock_task_manager.get_task.assert_not_called() - async def test_consume_and_break_on_message(self): + async def test_consume_and_break_on_message(self) -> None: sample_message = create_sample_message(content='interrupt message') event_after = create_sample_task('task_after_msg') @@ -234,7 +241,7 @@ async def mock_consume_generator(): @patch('asyncio.create_task') async def test_consume_and_break_on_auth_required_task_event( self, mock_create_task: MagicMock - ): + ) -> None: auth_task = create_sample_task( task_id='auth_task', status_state=TaskState.auth_required ) @@ -286,7 +293,7 @@ async def mock_consume_generator(): @patch('asyncio.create_task') async def test_consume_and_break_on_auth_required_status_update_event( self, mock_create_task: MagicMock - ): + ) -> None: auth_status_update = create_sample_status_update( task_id='auth_status_task', status_state=TaskState.auth_required ) @@ -325,7 +332,7 @@ async def mock_consume_generator(): self.aggregator._continue_consuming.call_args[0][0], AsyncIterator ) - async def test_consume_and_break_completes_normally(self): + async def test_consume_and_break_completes_normally(self) -> None: event1 = create_sample_message('event one normal', msg_id='n1') event2 = create_sample_task('normal_task') final_task_state = create_sample_task( @@ -357,7 +364,7 @@ async def mock_consume_generator(): self.mock_task_manager.process.assert_not_called() self.mock_task_manager.get_task.assert_not_called() - async def test_consume_and_break_event_consumer_exception(self): + async def test_consume_and_break_event_consumer_exception(self) -> None: class TestInterruptException(Exception): pass @@ -387,7 +394,7 @@ async def raiser_gen_interrupt(): @patch('asyncio.create_task') async def test_consume_and_break_non_blocking( self, mock_create_task: MagicMock - ): + ) -> None: """Test that with blocking=False, the method returns after the first event.""" first_event = create_sample_task('non_blocking_task') event_after = create_sample_message('should be consumed later') @@ -425,7 +432,7 @@ async def mock_consume_generator(): @patch('asyncio.create_task') # To verify _continue_consuming is called async def test_continue_consuming_processes_remaining_events( self, mock_create_task: MagicMock - ): + ) -> None: # This test focuses on verifying that if an interrupt occurs, # the events *after* the interrupting one are processed by _continue_consuming. diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index a8de65e3..891f8a10 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -20,13 +20,13 @@ @pytest.fixture -def event_queue(): +def event_queue() -> AsyncMock: """Create a mock event queue for testing.""" return AsyncMock(spec=EventQueue) @pytest.fixture -def task_updater(event_queue): +def task_updater(event_queue: AsyncMock) -> TaskUpdater: """Create a TaskUpdater instance for testing.""" return TaskUpdater( event_queue=event_queue, @@ -36,7 +36,7 @@ def task_updater(event_queue): @pytest.fixture -def sample_message(): +def sample_message() -> Message: """Create a sample message for testing.""" return Message( role=Role.agent, @@ -48,12 +48,12 @@ def sample_message(): @pytest.fixture -def sample_parts(): +def sample_parts() -> list[Part]: """Create sample parts for testing.""" return [Part(root=TextPart(text='Test part'))] -def test_init(event_queue): +def test_init(event_queue: AsyncMock) -> None: """Test that TaskUpdater initializes correctly.""" task_updater = TaskUpdater( event_queue=event_queue, @@ -67,7 +67,9 @@ def test_init(event_queue): @pytest.mark.asyncio -async def test_update_status_without_message(task_updater, event_queue): +async def test_update_status_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test updating status without a message.""" await task_updater.update_status(TaskState.working) @@ -84,8 +86,8 @@ async def test_update_status_without_message(task_updater, event_queue): @pytest.mark.asyncio async def test_update_status_with_message( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test updating status with a message.""" await task_updater.update_status(TaskState.working, message=sample_message) @@ -101,7 +103,9 @@ async def test_update_status_with_message( @pytest.mark.asyncio -async def test_update_status_final(task_updater, event_queue): +async def test_update_status_final( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test updating status with final=True.""" await task_updater.update_status(TaskState.completed, final=True) @@ -115,8 +119,8 @@ async def test_update_status_final(task_updater, event_queue): @pytest.mark.asyncio async def test_add_artifact_with_custom_id_and_name( - task_updater, event_queue, sample_parts -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_parts: list[Part] +) -> None: """Test adding an artifact with a custom ID and name.""" await task_updater.add_artifact( parts=sample_parts, @@ -135,8 +139,8 @@ async def test_add_artifact_with_custom_id_and_name( @pytest.mark.asyncio async def test_add_artifact_generates_id( - task_updater, event_queue, sample_parts -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_parts: list[Part] +) -> None: """Test add_artifact generates an ID if artifact_id is None.""" known_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678') with patch('uuid.uuid4', return_value=known_uuid): @@ -153,7 +157,9 @@ async def test_add_artifact_generates_id( @pytest.mark.asyncio -async def test_add_artifact_generates_custom_id(event_queue, sample_parts): +async def test_add_artifact_generates_custom_id( + event_queue: AsyncMock, sample_parts: list[Part] +) -> None: """Test add_artifact uses a custom ID generator when provided.""" artifact_id_generator = Mock(spec=IDGenerator) artifact_id_generator.generate.return_value = 'custom-artifact-id' @@ -183,8 +189,12 @@ async def test_add_artifact_generates_custom_id(event_queue, sample_parts): ], ) async def test_add_artifact_with_append_last_chunk( - task_updater, event_queue, sample_parts, append_val, last_chunk_val -): + task_updater: TaskUpdater, + event_queue: AsyncMock, + sample_parts: list[Part], + append_val: bool, + last_chunk_val: bool, +) -> None: """Test add_artifact with append and last_chunk flags.""" await task_updater.add_artifact( parts=sample_parts, @@ -204,7 +214,9 @@ async def test_add_artifact_with_append_last_chunk( @pytest.mark.asyncio -async def test_complete_without_message(task_updater, event_queue): +async def test_complete_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as completed without a message.""" await task_updater.complete() @@ -218,7 +230,9 @@ async def test_complete_without_message(task_updater, event_queue): @pytest.mark.asyncio -async def test_complete_with_message(task_updater, event_queue, sample_message): +async def test_complete_with_message( + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as completed with a message.""" await task_updater.complete(message=sample_message) @@ -232,7 +246,9 @@ async def test_complete_with_message(task_updater, event_queue, sample_message): @pytest.mark.asyncio -async def test_submit_without_message(task_updater, event_queue): +async def test_submit_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as submitted without a message.""" await task_updater.submit() @@ -246,7 +262,9 @@ async def test_submit_without_message(task_updater, event_queue): @pytest.mark.asyncio -async def test_submit_with_message(task_updater, event_queue, sample_message): +async def test_submit_with_message( + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as submitted with a message.""" await task_updater.submit(message=sample_message) @@ -260,7 +278,9 @@ async def test_submit_with_message(task_updater, event_queue, sample_message): @pytest.mark.asyncio -async def test_start_work_without_message(task_updater, event_queue): +async def test_start_work_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as working without a message.""" await task_updater.start_work() @@ -275,8 +295,8 @@ async def test_start_work_without_message(task_updater, event_queue): @pytest.mark.asyncio async def test_start_work_with_message( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as working with a message.""" await task_updater.start_work(message=sample_message) @@ -289,7 +309,9 @@ async def test_start_work_with_message( assert event.status.message == sample_message -def test_new_agent_message(task_updater, sample_parts): +def test_new_agent_message( + task_updater: TaskUpdater, sample_parts: list[Part] +) -> None: """Test creating a new agent message.""" with patch( 'uuid.uuid4', @@ -305,7 +327,9 @@ def test_new_agent_message(task_updater, sample_parts): assert message.metadata is None -def test_new_agent_message_with_metadata(task_updater, sample_parts): +def test_new_agent_message_with_metadata( + task_updater: TaskUpdater, sample_parts: list[Part] +) -> None: """Test creating a new agent message with metadata and final=True.""" metadata = {'key': 'value'} @@ -325,7 +349,9 @@ def test_new_agent_message_with_metadata(task_updater, sample_parts): assert message.metadata == metadata -def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts): +def test_new_agent_message_with_custom_id_generator( + event_queue: AsyncMock, sample_parts: list[Part] +) -> None: """Test creating a new agent message with a custom message ID generator.""" message_id_generator = Mock(spec=IDGenerator) message_id_generator.generate.return_value = 'custom-message-id' @@ -342,7 +368,9 @@ def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts): @pytest.mark.asyncio -async def test_failed_without_message(task_updater, event_queue): +async def test_failed_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as failed without a message.""" await task_updater.failed() @@ -356,7 +384,9 @@ async def test_failed_without_message(task_updater, event_queue): @pytest.mark.asyncio -async def test_failed_with_message(task_updater, event_queue, sample_message): +async def test_failed_with_message( + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as failed with a message.""" await task_updater.failed(message=sample_message) @@ -370,7 +400,9 @@ async def test_failed_with_message(task_updater, event_queue, sample_message): @pytest.mark.asyncio -async def test_reject_without_message(task_updater, event_queue): +async def test_reject_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as rejected without a message.""" await task_updater.reject() @@ -384,7 +416,9 @@ async def test_reject_without_message(task_updater, event_queue): @pytest.mark.asyncio -async def test_reject_with_message(task_updater, event_queue, sample_message): +async def test_reject_with_message( + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as rejected with a message.""" await task_updater.reject(message=sample_message) @@ -398,7 +432,9 @@ async def test_reject_with_message(task_updater, event_queue, sample_message): @pytest.mark.asyncio -async def test_requires_input_without_message(task_updater, event_queue): +async def test_requires_input_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as input required without a message.""" await task_updater.requires_input() @@ -413,8 +449,8 @@ async def test_requires_input_without_message(task_updater, event_queue): @pytest.mark.asyncio async def test_requires_input_with_message( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as input required with a message.""" await task_updater.requires_input(message=sample_message) @@ -428,7 +464,9 @@ async def test_requires_input_with_message( @pytest.mark.asyncio -async def test_requires_input_final_true(task_updater, event_queue): +async def test_requires_input_final_true( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as input required with final=True.""" await task_updater.requires_input(final=True) @@ -443,8 +481,8 @@ async def test_requires_input_final_true(task_updater, event_queue): @pytest.mark.asyncio async def test_requires_input_with_message_and_final( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as input required with message and final=True.""" await task_updater.requires_input(message=sample_message, final=True) @@ -458,7 +496,9 @@ async def test_requires_input_with_message_and_final( @pytest.mark.asyncio -async def test_requires_auth_without_message(task_updater, event_queue): +async def test_requires_auth_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as auth required without a message.""" await task_updater.requires_auth() @@ -473,8 +513,8 @@ async def test_requires_auth_without_message(task_updater, event_queue): @pytest.mark.asyncio async def test_requires_auth_with_message( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as auth required with a message.""" await task_updater.requires_auth(message=sample_message) @@ -488,7 +528,9 @@ async def test_requires_auth_with_message( @pytest.mark.asyncio -async def test_requires_auth_final_true(task_updater, event_queue): +async def test_requires_auth_final_true( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as auth required with final=True.""" await task_updater.requires_auth(final=True) @@ -503,8 +545,8 @@ async def test_requires_auth_final_true(task_updater, event_queue): @pytest.mark.asyncio async def test_requires_auth_with_message_and_final( - task_updater, event_queue, sample_message -): + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as auth required with message and final=True.""" await task_updater.requires_auth(message=sample_message, final=True) @@ -518,7 +560,9 @@ async def test_requires_auth_with_message_and_final( @pytest.mark.asyncio -async def test_cancel_without_message(task_updater, event_queue): +async def test_cancel_without_message( + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: """Test marking a task as cancelled without a message.""" await task_updater.cancel() @@ -532,7 +576,9 @@ async def test_cancel_without_message(task_updater, event_queue): @pytest.mark.asyncio -async def test_cancel_with_message(task_updater, event_queue, sample_message): +async def test_cancel_with_message( + task_updater: TaskUpdater, event_queue: AsyncMock, sample_message: Message +) -> None: """Test marking a task as cancelled with a message.""" await task_updater.cancel(message=sample_message) @@ -547,8 +593,8 @@ async def test_cancel_with_message(task_updater, event_queue, sample_message): @pytest.mark.asyncio async def test_update_status_raises_error_if_terminal_state_reached( - task_updater, event_queue -): + task_updater: TaskUpdater, event_queue: AsyncMock +) -> None: await task_updater.complete() event_queue.reset_mock() with pytest.raises(RuntimeError): @@ -557,7 +603,9 @@ async def test_update_status_raises_error_if_terminal_state_reached( @pytest.mark.asyncio -async def test_concurrent_updates_race_condition(event_queue): +async def test_concurrent_updates_race_condition( + event_queue: AsyncMock, +) -> None: task_updater = TaskUpdater( event_queue=event_queue, task_id='test-task-id', @@ -576,7 +624,9 @@ async def test_concurrent_updates_race_condition(event_queue): @pytest.mark.asyncio -async def test_reject_concurrently_with_complete(event_queue): +async def test_reject_concurrently_with_complete( + event_queue: AsyncMock, +) -> None: """Test for race conditions when reject and complete are called concurrently.""" task_updater = TaskUpdater( event_queue=event_queue, diff --git a/tests/utils/test_telemetry.py b/tests/utils/test_telemetry.py index 5109379b..eae96b19 100644 --- a/tests/utils/test_telemetry.py +++ b/tests/utils/test_telemetry.py @@ -1,6 +1,7 @@ import asyncio -from typing import NoReturn +from collections.abc import Generator +from typing import Any, NoReturn from unittest import mock import pytest @@ -9,12 +10,12 @@ @pytest.fixture -def mock_span(): +def mock_span() -> mock.MagicMock: return mock.MagicMock() @pytest.fixture -def mock_tracer(mock_span): +def mock_tracer(mock_span: mock.MagicMock) -> mock.MagicMock: tracer = mock.MagicMock() tracer.start_as_current_span.return_value.__enter__.return_value = mock_span tracer.start_as_current_span.return_value.__exit__.return_value = False @@ -22,12 +23,14 @@ def mock_tracer(mock_span): @pytest.fixture(autouse=True) -def patch_trace_get_tracer(mock_tracer): +def patch_trace_get_tracer( + mock_tracer: mock.MagicMock, +) -> Generator[None, Any, None]: with mock.patch('opentelemetry.trace.get_tracer', return_value=mock_tracer): yield -def test_trace_function_sync_success(mock_span): +def test_trace_function_sync_success(mock_span: mock.MagicMock) -> None: @trace_function def foo(x, y): return x + y @@ -39,7 +42,7 @@ def foo(x, y): mock_span.record_exception.assert_not_called() -def test_trace_function_sync_exception(mock_span): +def test_trace_function_sync_exception(mock_span: mock.MagicMock) -> None: @trace_function def bar() -> NoReturn: raise ValueError('fail') @@ -50,7 +53,9 @@ def bar() -> NoReturn: mock_span.set_status.assert_any_call(mock.ANY, description='fail') -def test_trace_function_sync_attribute_extractor_called(mock_span): +def test_trace_function_sync_attribute_extractor_called( + mock_span: mock.MagicMock, +) -> None: called = {} def attr_extractor(span, args, kwargs, result, exception) -> None: @@ -67,7 +72,9 @@ def foo() -> int: assert called['called'] -def test_trace_function_sync_attribute_extractor_error_logged(mock_span): +def test_trace_function_sync_attribute_extractor_error_logged( + mock_span: mock.MagicMock, +) -> None: with mock.patch('a2a.utils.telemetry.logger') as logger: def attr_extractor(span, args, kwargs, result, exception) -> NoReturn: @@ -85,7 +92,7 @@ def foo() -> int: @pytest.mark.asyncio -async def test_trace_function_async_success(mock_span): +async def test_trace_function_async_success(mock_span: mock.MagicMock) -> None: @trace_function async def foo(x): await asyncio.sleep(0) @@ -98,7 +105,9 @@ async def foo(x): @pytest.mark.asyncio -async def test_trace_function_async_exception(mock_span): +async def test_trace_function_async_exception( + mock_span: mock.MagicMock, +) -> None: @trace_function async def bar() -> NoReturn: await asyncio.sleep(0) @@ -111,7 +120,9 @@ async def bar() -> NoReturn: @pytest.mark.asyncio -async def test_trace_function_async_attribute_extractor_called(mock_span): +async def test_trace_function_async_attribute_extractor_called( + mock_span: mock.MagicMock, +) -> None: called = {} def attr_extractor(span, args, kwargs, result, exception) -> None: @@ -127,7 +138,9 @@ async def foo() -> int: assert called['called'] -def test_trace_function_with_args_and_attributes(mock_span): +def test_trace_function_with_args_and_attributes( + mock_span: mock.MagicMock, +) -> None: @trace_function(span_name='custom.span', attributes={'foo': 'bar'}) def foo() -> int: return 1 @@ -136,7 +149,7 @@ def foo() -> int: mock_span.set_attribute.assert_any_call('foo', 'bar') -def test_trace_class_exclude_list(mock_span): +def test_trace_class_exclude_list(mock_span: mock.MagicMock) -> None: @trace_class(exclude_list=['skip_me']) class MyClass: def a(self) -> str: @@ -145,7 +158,7 @@ def a(self) -> str: def skip_me(self) -> str: return 'skip' - def __str__(self): + def __str__(self) -> str: return 'str' obj = MyClass() @@ -156,7 +169,7 @@ def __str__(self): assert not hasattr(obj.skip_me, '__wrapped__') -def test_trace_class_include_list(mock_span): +def test_trace_class_include_list(mock_span: mock.MagicMock) -> None: @trace_class(include_list=['only_this']) class MyClass: def only_this(self) -> str: @@ -172,10 +185,10 @@ def not_this(self) -> str: assert not hasattr(obj.not_this, '__wrapped__') -def test_trace_class_dunder_not_traced(mock_span): +def test_trace_class_dunder_not_traced(mock_span: mock.MagicMock) -> None: @trace_class() class MyClass: - def __init__(self): + def __init__(self) -> None: self.x = 1 def foo(self) -> str: