Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/a2a/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions for the A2A Python SDK."""

from a2a.utils.artifact import (
get_artifact_text,
new_artifact,
new_data_artifact,
new_text_artifact,
Expand All @@ -18,13 +19,15 @@
create_task_obj,
)
from a2a.utils.message import (
get_data_parts,
get_file_parts,
get_message_text,
get_text_parts,
new_agent_parts_message,
new_agent_text_message,
)
from a2a.utils.parts import (
get_data_parts,
get_file_parts,
get_text_parts,
)
from a2a.utils.task import (
completed_task,
new_task,
Expand All @@ -41,6 +44,7 @@
'build_text_artifact',
'completed_task',
'create_task_obj',
'get_artifact_text',
'get_data_parts',
'get_file_parts',
'get_message_text',
Expand Down
14 changes: 14 additions & 0 deletions src/a2a/utils/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

from a2a.types import Artifact, DataPart, Part, TextPart
from a2a.utils.parts import get_text_parts


def new_artifact(
Expand Down Expand Up @@ -70,3 +71,16 @@ def new_data_artifact(
name,
description,
)


def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str:
"""Extracts and joins all text content from an Artifact's parts.

Args:
artifact: The `Artifact` object.
delimiter: The string to use when joining text from multiple TextParts.

Returns:
A single string containing all text content, or an empty string if no text parts are found.
"""
return delimiter.join(get_text_parts(artifact.parts))
43 changes: 1 addition & 42 deletions src/a2a/utils/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@

import uuid

from typing import Any

from a2a.types import (
DataPart,
FilePart,
FileWithBytes,
FileWithUri,
Message,
Part,
Role,
TextPart,
)
from a2a.utils.parts import get_text_parts


def new_agent_text_message(
Expand Down Expand Up @@ -64,42 +59,6 @@ def new_agent_parts_message(
)


def get_text_parts(parts: list[Part]) -> list[str]:
"""Extracts text content from all TextPart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of strings containing the text content from any `TextPart` objects found.
"""
return [part.root.text for part in parts if isinstance(part.root, TextPart)]


def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
"""Extracts dictionary data from all DataPart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of dictionaries containing the data from any `DataPart` objects found.
"""
return [part.root.data for part in parts if isinstance(part.root, DataPart)]


def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]:
"""Extracts file data from all FilePart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found.
"""
return [part.root.file for part in parts if isinstance(part.root, FilePart)]


def get_message_text(message: Message, delimiter: str = '\n') -> str:
"""Extracts and joins all text content from a Message's parts.

Expand Down
48 changes: 48 additions & 0 deletions src/a2a/utils/parts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Utility functions for creating and handling A2A Parts objects."""

from typing import Any

from a2a.types import (
DataPart,
FilePart,
FileWithBytes,
FileWithUri,
Part,
TextPart,
)


def get_text_parts(parts: list[Part]) -> list[str]:
"""Extracts text content from all TextPart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of strings containing the text content from any `TextPart` objects found.
"""
return [part.root.text for part in parts if isinstance(part.root, TextPart)]


def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
"""Extracts dictionary data from all DataPart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of dictionaries containing the data from any `DataPart` objects found.
"""
return [part.root.data for part in parts if isinstance(part.root, DataPart)]


def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]:
"""Extracts file data from all FilePart objects in a list of Parts.

Args:
parts: A list of `Part` objects.

Returns:
A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found.
"""
return [part.root.file for part in parts if isinstance(part.root, FilePart)]
74 changes: 73 additions & 1 deletion tests/utils/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

from unittest.mock import patch

from a2a.types import DataPart, Part, TextPart
from a2a.types import (
Artifact,
DataPart,
Part,
TextPart,
)
from a2a.utils.artifact import (
get_artifact_text,
new_artifact,
new_data_artifact,
new_text_artifact,
Expand Down Expand Up @@ -83,5 +89,71 @@ def test_new_data_artifact_assigns_name_description(self):
self.assertEqual(artifact.description, description)


class TestGetArtifactText(unittest.TestCase):
def test_get_artifact_text_single_part(self):
# Setup
artifact = Artifact(
name='test-artifact',
parts=[Part(root=TextPart(text='Hello world'))],
artifact_id='test-artifact-id',
)

# Exercise
result = get_artifact_text(artifact)

# Verify
assert result == 'Hello world'

def test_get_artifact_text_multiple_parts(self):
# Setup
artifact = Artifact(
name='test-artifact',
parts=[
Part(root=TextPart(text='First line')),
Part(root=TextPart(text='Second line')),
Part(root=TextPart(text='Third line')),
],
artifact_id='test-artifact-id',
)

# Exercise
result = get_artifact_text(artifact)

# Verify - default delimiter is newline
assert result == 'First line\nSecond line\nThird line'

def test_get_artifact_text_custom_delimiter(self):
# Setup
artifact = Artifact(
name='test-artifact',
parts=[
Part(root=TextPart(text='First part')),
Part(root=TextPart(text='Second part')),
Part(root=TextPart(text='Third part')),
],
artifact_id='test-artifact-id',
)

# Exercise
result = get_artifact_text(artifact, delimiter=' | ')

# Verify
assert result == 'First part | Second part | Third part'

def test_get_artifact_text_empty_parts(self):
# Setup
artifact = Artifact(
name='test-artifact',
parts=[],
artifact_id='test-artifact-id',
)

# Exercise
result = get_artifact_text(artifact)

# Verify
assert result == ''


if __name__ == '__main__':
unittest.main()
Loading
Loading