diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 870c5f8e..a0c81ec8 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -4,6 +4,7 @@ InvalidParamsError, Message, MessageSendParams, + MessageSendConfiguration, Task, ) from a2a.utils import get_message_text @@ -79,6 +80,10 @@ def task_id(self) -> str | None: def context_id(self) -> str | None: return self._context_id + @property + def configuration(self) -> MessageSendConfiguration | None: + return self._params.configuration + def _check_or_generate_task_id(self) -> None: if not self._params: return diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 39751c92..1bc9641c 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -42,6 +42,7 @@ def add_artifact( parts: list[Part], artifact_id=str(uuid.uuid4()), name: str | None = None, + metadata: dict[str, any] | None = None, ): """Add an artifact to the task.""" self.event_queue.enqueue_event( @@ -52,6 +53,7 @@ def add_artifact( artifactId=artifact_id, name=name, parts=parts, + metadata=metadata, ), ) ) @@ -64,6 +66,14 @@ def complete(self, message: Message | None = None): final=True, ) + def failed(self, message: Message | None = None): + """Mark the task as failed.""" + self.update_status( + TaskState.failed, + message=message, + final=True + ) + def submit(self, message: Message | None = None): """Mark the task as submitted.""" self.update_status( diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 42e5d37e..dfafb7b5 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,15 +1,24 @@ -from a2a.utils.artifact import new_text_artifact +from a2a.utils.artifact import ( + new_artifact, + new_data_artifact, + new_text_artifact, +) from a2a.utils.helpers import ( append_artifact_to_task, build_text_artifact, create_task_obj, + are_modalities_compatible, ) from a2a.utils.message import ( get_message_text, get_text_parts, new_agent_text_message, + new_agent_parts_message, +) +from a2a.utils.task import ( + new_task, + completed_task, ) -from a2a.utils.task import new_task __all__ = [ @@ -21,4 +30,9 @@ 'new_agent_text_message', 'new_task', 'new_text_artifact', + 'new_agent_parts_message', + 'completed_task', + 'new_artifact', + 'new_data_artifact', + 'are_modalities_compatible', ] diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index d282b686..63723340 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -3,14 +3,30 @@ from a2a.types import Artifact, Part, TextPart -def new_text_artifact( - name: str, - text: str, - description: str = '', +def new_artifact( + parts: list[Part], name: str, description: str = '' ) -> Artifact: return Artifact( artifactId=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=text))], + parts=parts, name=name, description=description, ) + +def new_text_artifact( + name: str, + text: str, + description: str = '', +) -> Artifact: + return new_artifact( + [Part(root=TextPart(text=text))], name, description, + ) + +def new_data_artifact( + name: str, + data: dict[str, any], + description: str = '', +): + return new_artifact( + [Part(root=DataPart(data=data))], name, description, + ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 651c37ea..648dd248 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -104,3 +104,18 @@ def wrapper(self, *args, **kwargs): return wrapper return decorator + + +def are_modalities_compatible( + server_output_modes: list[str], client_output_modes: list[str] +): + """Modalities are compatible if they are both non-empty + and there is at least one common element. + """ + if client_output_modes is None or len(client_output_modes) == 0: + return True + + if server_output_modes is None or len(server_output_modes) == 0: + return True + + return any(x in server_output_modes for x in client_output_modes) diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py index cf153a79..5a198a1b 100644 --- a/src/a2a/utils/message.py +++ b/src/a2a/utils/message.py @@ -9,7 +9,9 @@ def new_agent_text_message( - text: str, context_id: str | None = None, task_id: str | None = None + text: str, + context_id: str | None = None, + task_id: str | None = None, ) -> Message: """Creates a new agent text message.""" return Message( @@ -20,6 +22,18 @@ def new_agent_text_message( contextId=context_id, ) +def new_agent_parts_message( + parts: list[Part], + context_id: str | None, + task_id: str | None = None, +): + return Message( + role=Role.agent, + parts=parts, + messageId=str(uuid.uuid4()), + taskId=task_id, + contextId=context_id, + ) def get_text_parts(parts: list[Part]) -> list[str]: """Return all text parts from a list of parts.""" diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index cd0da7e4..ab4eee3b 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -1,6 +1,6 @@ import uuid -from a2a.types import Message, Task, TaskState, TaskStatus +from a2a.types import Artifact, Message, Task, TaskState, TaskStatus def new_task(request: Message) -> Task: @@ -12,3 +12,18 @@ def new_task(request: Message) -> Task: ), history=[request], ) + + +def completed_task( + task_id: str, + context_id: str, + artifacts: list[Artifact], + history: list[Message] = [], +) -> Task: + return Task( + status=TaskStatus(state=TaskState.completed), + id=task_id, + contextId=context_id, + artifacts=artifacts, + history=history, + )