diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index 7b4de3ec..380cbd00 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -81,8 +81,7 @@ \.xz$ \.zip$ ^\.github/actions/spelling/ -^\Q.github/workflows/spelling.yaml\E$ -^\Q.github/workflows/linter.yaml\E$ +^\.github/workflows/ \.gitignore\E$ \.vscode/ noxfile.py diff --git a/.github/linters/.mypy.ini b/.github/linters/.mypy.ini new file mode 100644 index 00000000..80ff63e5 --- /dev/null +++ b/.github/linters/.mypy.ini @@ -0,0 +1,6 @@ +[mypy] +exclude = examples/ +disable_error_code = import-not-found + +[mypy-examples.*] +follow_imports = skip diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index f4c6aeac..e7b13780 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -13,10 +13,9 @@ name: Lint Code Base ############################# # Start the job on all push # ############################# -# on: -# pull_request: -# branches: [main] -on: workflow_dispatch +on: + pull_request: + branches: [main] ############### # Set the Job # @@ -64,3 +63,5 @@ jobs: VALIDATE_TYPESCRIPT_STANDARD: false VALIDATE_GIT_COMMITLINT: false MARKDOWN_CONFIG_FILE: .markdownlint.json + PYTHON_MYPY_CONFIG_FILE: .mypy.ini + FILTER_REGEX_EXCLUDE: "^examples/.*" diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 8cc07133..bf7414cc 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -24,7 +24,7 @@ jobs: - name: Build run: uv build - + - name: Upload distributions uses: actions/upload-artifact@v4 with: @@ -49,6 +49,3 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: packages-dir: dist/ - - - diff --git a/examples/google_adk/birthday_planner/adk_agent_executor.py b/examples/google_adk/birthday_planner/adk_agent_executor.py index 4cf6a569..7b4c396f 100644 --- a/examples/google_adk/birthday_planner/adk_agent_executor.py +++ b/examples/google_adk/birthday_planner/adk_agent_executor.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import asyncio import logging diff --git a/examples/google_adk/calendar_agent/adk_agent_executor.py b/examples/google_adk/calendar_agent/adk_agent_executor.py index 9182227d..b52d74ad 100644 --- a/examples/google_adk/calendar_agent/adk_agent_executor.py +++ b/examples/google_adk/calendar_agent/adk_agent_executor.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import asyncio import logging @@ -53,7 +54,7 @@ def __init__(self, runner: Runner, card: AgentCard): def _run_agent( self, session_id, new_message: types.Content - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event]: return self.runner.run_async( session_id=session_id, user_id='self', new_message=new_message ) diff --git a/noxfile.py b/noxfile.py index fd9569fb..d731b6e1 100644 --- a/noxfile.py +++ b/noxfile.py @@ -114,9 +114,15 @@ def format(session): 'pyupgrade', 'autoflake', 'ruff', + 'no_implicit_optional', ) if lint_paths_py: + session.run( + 'no_implicit_optional', + '--use-union-or', + *lint_paths_py, + ) if not format_all: session.run( 'pyupgrade', diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index a0c81ec8..71169dd2 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -3,8 +3,8 @@ from a2a.types import ( InvalidParamsError, Message, - MessageSendParams, MessageSendConfiguration, + MessageSendParams, Task, ) from a2a.utils import get_message_text @@ -82,6 +82,8 @@ def context_id(self) -> str | None: @property def configuration(self) -> MessageSendConfiguration | None: + if not self._params: + return None return self._params.configuration def _check_or_generate_task_id(self) -> None: diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index c3581c4c..17cbd49f 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -128,9 +128,14 @@ async def on_message_send( if task: task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): - assert isinstance(self._push_notifier, PushNotifier) # For typechecker - assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker - assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker + assert isinstance(self._push_notifier, PushNotifier) + assert isinstance( + params.configuration, MessageSendConfiguration + ) + assert isinstance( + params.configuration.pushNotificationConfig, + PushNotificationConfig, + ) await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) @@ -193,9 +198,14 @@ async def on_message_send_stream( task = task_manager.update_with_message(params.message, task) if self.should_add_push_info(params): - assert isinstance(self._push_notifier, PushNotifier) # For typechecker - assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker - assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker + assert isinstance(self._push_notifier, PushNotifier) + assert isinstance( + params.configuration, MessageSendConfiguration + ) + assert isinstance( + params.configuration.pushNotificationConfig, + PushNotificationConfig, + ) await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) @@ -324,11 +334,8 @@ async def on_resubscribe_to_task( yield event def should_add_push_info(self, params: MessageSendParams) -> bool: - if ( + return bool( self._push_notifier and params.configuration and params.configuration.pushNotificationConfig - ): - return True - else: - return False + ) diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 1bc9641c..fdecde5a 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -1,5 +1,7 @@ import uuid +from typing import Any + from a2a.server.events import EventQueue from a2a.types import ( Artifact, @@ -42,7 +44,7 @@ def add_artifact( parts: list[Part], artifact_id=str(uuid.uuid4()), name: str | None = None, - metadata: dict[str, any] | None = None, + metadata: dict[str, Any] | None = None, ): """Add an artifact to the task.""" self.event_queue.enqueue_event( @@ -68,11 +70,7 @@ def complete(self, message: Message | None = None): def failed(self, message: Message | None = None): """Mark the task as failed.""" - self.update_status( - TaskState.failed, - message=message, - final=True - ) + self.update_status(TaskState.failed, message=message, final=True) def submit(self, message: Message | None = None): """Mark the task as submitted.""" diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index dfafb7b5..840765c3 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -5,34 +5,34 @@ ) from a2a.utils.helpers import ( append_artifact_to_task, + are_modalities_compatible, build_text_artifact, create_task_obj, - are_modalities_compatible, ) from a2a.utils.message import ( get_message_text, get_text_parts, - new_agent_text_message, new_agent_parts_message, + new_agent_text_message, ) from a2a.utils.task import ( - new_task, completed_task, + new_task, ) __all__ = [ 'append_artifact_to_task', + 'are_modalities_compatible', 'build_text_artifact', + 'completed_task', 'create_task_obj', 'get_message_text', 'get_text_parts', - 'new_agent_text_message', - 'new_task', - 'new_text_artifact', 'new_agent_parts_message', - 'completed_task', + 'new_agent_text_message', 'new_artifact', 'new_data_artifact', - 'are_modalities_compatible', + 'new_task', + 'new_text_artifact', ] diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index 63723340..bcf92286 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -1,6 +1,8 @@ import uuid -from a2a.types import Artifact, Part, TextPart +from typing import Any + +from a2a.types import Artifact, DataPart, Part, TextPart def new_artifact( @@ -13,20 +15,26 @@ def new_artifact( description=description, ) + def new_text_artifact( name: str, text: str, description: str = '', ) -> Artifact: return new_artifact( - [Part(root=TextPart(text=text))], name, description, + [Part(root=TextPart(text=text))], + name, + description, ) + def new_data_artifact( name: str, - data: dict[str, any], + data: dict[str, Any], description: str = '', ): return new_artifact( - [Part(root=DataPart(data=data))], name, description, + [Part(root=DataPart(data=data))], + name, + description, ) diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py index 5a198a1b..a984e557 100644 --- a/src/a2a/utils/message.py +++ b/src/a2a/utils/message.py @@ -22,6 +22,7 @@ def new_agent_text_message( contextId=context_id, ) + def new_agent_parts_message( parts: list[Part], context_id: str | None, @@ -35,6 +36,7 @@ def new_agent_parts_message( contextId=context_id, ) + def get_text_parts(parts: list[Part]) -> list[str]: """Return all text parts from a list of parts.""" return [part.root.text for part in parts if isinstance(part.root, TextPart)] diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index ab4eee3b..52f9596a 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -18,8 +18,10 @@ def completed_task( task_id: str, context_id: str, artifacts: list[Artifact], - history: list[Message] = [], + history: list[Message] | None = None, ) -> Task: + if history is None: + history = [] return Task( status=TaskStatus(state=TaskState.completed), id=task_id, diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 4b986288..e556b9c8 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -1,25 +1,25 @@ from typing import Any + import pytest -from unittest.mock import MagicMock -from uuid import uuid4 -from a2a.utils.helpers import ( - create_task_obj, - append_artifact_to_task, - build_text_artifact, - validate, -) + from a2a.types import ( Artifact, - MessageSendParams, Message, + MessageSendParams, + Part, Task, TaskArtifactUpdateEvent, TaskState, - TaskStatus, TextPart, - Part, ) -from a2a.utils.errors import ServerError, UnsupportedOperationError +from a2a.utils.errors import ServerError +from a2a.utils.helpers import ( + append_artifact_to_task, + build_text_artifact, + create_task_obj, + validate, +) + # --- Helper Data --- TEXT_PART_DATA: dict[str, Any] = {'type': 'text', 'text': 'Hello'} @@ -40,6 +40,7 @@ 'type': 'task', } + # Test create_task_obj def test_create_task_obj(): message = Message(**MINIMAL_MESSAGE_USER) @@ -55,7 +56,7 @@ def test_create_task_obj(): # Test append_artifact_to_task def test_append_artifact_to_task(): - # Prepare base task + # Prepare base task task = Task(**MINIMAL_TASK) assert task.id == 'task-abc' assert task.contextId == 'session-xyz' @@ -66,66 +67,86 @@ def test_append_artifact_to_task(): # Prepare appending artifact and event artifact_1 = Artifact( - artifactId="artifact-123", parts=[Part(root=TextPart(text="Hello"))] + artifactId='artifact-123', parts=[Part(root=TextPart(text='Hello'))] + ) + append_event_1 = TaskArtifactUpdateEvent( + artifact=artifact_1, append=False, taskId='123', contextId='123' ) - append_event_1 = TaskArtifactUpdateEvent(artifact=artifact_1, append=False, taskId="123", contextId="123") # Test adding a new artifact (not appending) append_artifact_to_task(task, append_event_1) assert len(task.artifacts) == 1 - assert task.artifacts[0].artifactId == "artifact-123" - assert task.artifacts[0].name == None + assert task.artifacts[0].artifactId == 'artifact-123' + assert task.artifacts[0].name is None assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].root.text == "Hello" + assert task.artifacts[0].parts[0].root.text == 'Hello' # Test replacing the artifact artifact_2 = Artifact( - artifactId="artifact-123", name="updated name", parts=[Part(root=TextPart(text="Updated"))] + artifactId='artifact-123', + name='updated name', + parts=[Part(root=TextPart(text='Updated'))], + ) + append_event_2 = TaskArtifactUpdateEvent( + artifact=artifact_2, append=False, taskId='123', contextId='123' ) - append_event_2 = TaskArtifactUpdateEvent(artifact=artifact_2, append=False, taskId="123", contextId="123") append_artifact_to_task(task, append_event_2) assert len(task.artifacts) == 1 # Should still have one artifact - assert task.artifacts[0].artifactId == "artifact-123" - assert task.artifacts[0].name == "updated name" + assert task.artifacts[0].artifactId == 'artifact-123' + assert task.artifacts[0].name == 'updated name' assert len(task.artifacts[0].parts) == 1 - assert task.artifacts[0].parts[0].root.text == "Updated" + assert task.artifacts[0].parts[0].root.text == 'Updated' # Test appending parts to an existing artifact artifact_with_parts = Artifact( - artifactId="artifact-123", parts=[Part(root=TextPart(text="Part 2"))] + artifactId='artifact-123', parts=[Part(root=TextPart(text='Part 2'))] + ) + append_event_3 = TaskArtifactUpdateEvent( + artifact=artifact_with_parts, append=True, taskId='123', contextId='123' ) - append_event_3 = TaskArtifactUpdateEvent(artifact=artifact_with_parts, append=True, taskId="123", contextId="123") append_artifact_to_task(task, append_event_3) assert len(task.artifacts[0].parts) == 2 - assert task.artifacts[0].parts[0].root.text == "Updated" - assert task.artifacts[0].parts[1].root.text == "Part 2" + assert task.artifacts[0].parts[0].root.text == 'Updated' + assert task.artifacts[0].parts[1].root.text == 'Part 2' # Test adding another new artifact another_artifact_with_parts = Artifact( - artifactId="new_artifact", parts=[Part(root=TextPart(text="new artifact Part 1"))] + artifactId='new_artifact', + parts=[Part(root=TextPart(text='new artifact Part 1'))], + ) + append_event_4 = TaskArtifactUpdateEvent( + artifact=another_artifact_with_parts, + append=False, + taskId='123', + contextId='123', ) - append_event_4 = TaskArtifactUpdateEvent(artifact=another_artifact_with_parts, append=False, taskId="123", contextId="123") append_artifact_to_task(task, append_event_4) assert len(task.artifacts) == 2 - assert task.artifacts[0].artifactId == "artifact-123" - assert task.artifacts[1].artifactId == "new_artifact" + assert task.artifacts[0].artifactId == 'artifact-123' + assert task.artifacts[1].artifactId == 'new_artifact' assert len(task.artifacts[0].parts) == 2 assert len(task.artifacts[1].parts) == 1 # Test appending part to a task that does not have a matching artifact non_existing_artifact_with_parts = Artifact( - artifactId="artifact-456", parts=[Part(root=TextPart(text="Part 1"))] + artifactId='artifact-456', parts=[Part(root=TextPart(text='Part 1'))] + ) + append_event_5 = TaskArtifactUpdateEvent( + artifact=non_existing_artifact_with_parts, + append=True, + taskId='123', + contextId='123', ) - append_event_5 = TaskArtifactUpdateEvent(artifact=non_existing_artifact_with_parts, append=True, taskId="123", contextId="123") append_artifact_to_task(task, append_event_5) assert len(task.artifacts) == 2 assert len(task.artifacts[0].parts) == 2 assert len(task.artifacts[1].parts) == 1 + # Test build_text_artifact def test_build_text_artifact(): - artifact_id = "text_artifact" - text = "This is a sample text" + artifact_id = 'text_artifact' + text = 'This is a sample text' artifact = build_text_artifact(text, artifact_id) assert artifact.artifactId == artifact_id @@ -138,17 +159,17 @@ def test_validate_decorator(): class TestClass: condition = True - @validate(lambda self: self.condition, "Condition not met") + @validate(lambda self: self.condition, 'Condition not met') def test_method(self): - return "Success" + return 'Success' obj = TestClass() # Test passing condition - assert obj.test_method() == "Success" + assert obj.test_method() == 'Success' # Test failing condition obj.condition = False with pytest.raises(ServerError) as exc_info: obj.test_method() - assert "Condition not met" in str(exc_info.value) + assert 'Condition not met' in str(exc_info.value) diff --git a/tests/utils/test_telemetry.py b/tests/utils/test_telemetry.py index 075cca28..90ea17b0 100644 --- a/tests/utils/test_telemetry.py +++ b/tests/utils/test_telemetry.py @@ -1,13 +1,16 @@ -import pytest -import types import asyncio + from unittest import mock -from a2a.utils.telemetry import trace_function, trace_class + +import pytest + +from a2a.utils.telemetry import trace_class, trace_function + @pytest.fixture def mock_span(): - span = mock.MagicMock() - return span + return mock.MagicMock() + @pytest.fixture def mock_tracer(mock_span): @@ -16,11 +19,13 @@ def mock_tracer(mock_span): tracer.start_as_current_span.return_value.__exit__.return_value = False return tracer + @pytest.fixture(autouse=True) def patch_trace_get_tracer(mock_tracer): - with mock.patch("opentelemetry.trace.get_tracer", return_value=mock_tracer): + with mock.patch('opentelemetry.trace.get_tracer', return_value=mock_tracer): yield + def test_trace_function_sync_success(mock_span): @trace_function def foo(x, y): @@ -32,18 +37,21 @@ def foo(x, y): mock_span.set_status.assert_any_call(mock.ANY) mock_span.record_exception.assert_not_called() + def test_trace_function_sync_exception(mock_span): @trace_function def bar(): - raise ValueError("fail") + raise ValueError('fail') with pytest.raises(ValueError): bar() mock_span.record_exception.assert_called() - mock_span.set_status.assert_any_call(mock.ANY, description="fail") + mock_span.set_status.assert_any_call(mock.ANY, description='fail') + def test_trace_function_sync_attribute_extractor_called(mock_span): called = {} + def attr_extractor(span, args, kwargs, result, exception): called['called'] = True assert span is mock_span @@ -57,10 +65,12 @@ def foo(): foo() assert called['called'] + def test_trace_function_sync_attribute_extractor_error_logged(mock_span): - with mock.patch("a2a.utils.telemetry.logger") as logger: + with mock.patch('a2a.utils.telemetry.logger') as logger: + def attr_extractor(span, args, kwargs, result, exception): - raise RuntimeError("attr fail") + raise RuntimeError('attr fail') @trace_function(attribute_extractor=attr_extractor) def foo(): @@ -69,6 +79,7 @@ def foo(): foo() logger.error.assert_any_call(mock.ANY) + @pytest.mark.asyncio async def test_trace_function_async_success(mock_span): @trace_function @@ -81,21 +92,24 @@ async def foo(x): mock_span.set_status.assert_called() mock_span.record_exception.assert_not_called() + @pytest.mark.asyncio async def test_trace_function_async_exception(mock_span): @trace_function async def bar(): await asyncio.sleep(0) - raise RuntimeError("async fail") + raise RuntimeError('async fail') with pytest.raises(RuntimeError): await bar() mock_span.record_exception.assert_called() - mock_span.set_status.assert_any_call(mock.ANY, description="async fail") + mock_span.set_status.assert_any_call(mock.ANY, description='async fail') + @pytest.mark.asyncio async def test_trace_function_async_attribute_extractor_called(mock_span): called = {} + def attr_extractor(span, args, kwargs, result, exception): called['called'] = True assert exception is None @@ -108,47 +122,62 @@ async def foo(): await foo() assert called['called'] + def test_trace_function_with_args_and_attributes(mock_span): - @trace_function(span_name="custom.span", attributes={"foo": "bar"}) + @trace_function(span_name='custom.span', attributes={'foo': 'bar'}) def foo(): return 1 foo() - mock_span.set_attribute.assert_any_call("foo", "bar") + mock_span.set_attribute.assert_any_call('foo', 'bar') + def test_trace_class_exclude_list(mock_span): - @trace_class(exclude_list=["skip_me"]) + @trace_class(exclude_list=['skip_me']) class MyClass: - def a(self): return "a" - def skip_me(self): return "skip" - def __str__(self): return "str" + def a(self): + return 'a' + + def skip_me(self): + return 'skip' + + def __str__(self): + return 'str' obj = MyClass() - assert obj.a() == "a" - assert obj.skip_me() == "skip" + assert obj.a() == 'a' + assert obj.skip_me() == 'skip' # Only 'a' is traced, not 'skip_me' or dunder - assert hasattr(obj.a, "__wrapped__") - assert not hasattr(obj.skip_me, "__wrapped__") + assert hasattr(obj.a, '__wrapped__') + assert not hasattr(obj.skip_me, '__wrapped__') + def test_trace_class_include_list(mock_span): - @trace_class(include_list=["only_this"]) + @trace_class(include_list=['only_this']) class MyClass: - def only_this(self): return "yes" - def not_this(self): return "no" + def only_this(self): + return 'yes' + + def not_this(self): + return 'no' obj = MyClass() - assert obj.only_this() == "yes" - assert obj.not_this() == "no" - assert hasattr(obj.only_this, "__wrapped__") - assert not hasattr(obj.not_this, "__wrapped__") + assert obj.only_this() == 'yes' + assert obj.not_this() == 'no' + assert hasattr(obj.only_this, '__wrapped__') + assert not hasattr(obj.not_this, '__wrapped__') + def test_trace_class_dunder_not_traced(mock_span): @trace_class() class MyClass: - def __init__(self): self.x = 1 - def foo(self): return "foo" + def __init__(self): + self.x = 1 + + def foo(self): + return 'foo' obj = MyClass() - assert obj.foo() == "foo" - assert hasattr(obj.foo, "__wrapped__") - assert hasattr(obj, "x") \ No newline at end of file + assert obj.foo() == 'foo' + assert hasattr(obj.foo, '__wrapped__') + assert hasattr(obj, 'x')