Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import uuid

from typing import Any

from a2a.server.context import ServerCallContext
from a2a.server.id_generator import (
IDGenerator,
IDGeneratorContext,
UUIDGenerator,
)
from a2a.types import (
InvalidParamsError,
Message,
Expand Down Expand Up @@ -30,6 +33,8 @@ def __init__( # noqa: PLR0913
task: Task | None = None,
related_tasks: list[Task] | None = None,
call_context: ServerCallContext | None = None,
task_id_generator: IDGenerator | None = None,
context_id_generator: IDGenerator | None = None,
):
"""Initializes the RequestContext.

Expand All @@ -40,6 +45,8 @@ def __init__( # noqa: PLR0913
task: The existing `Task` object retrieved from the store, if any.
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
call_context: The server call context associated with this request.
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
"""
if related_tasks is None:
related_tasks = []
Expand All @@ -49,6 +56,12 @@ def __init__( # noqa: PLR0913
self._current_task = task
self._related_tasks = related_tasks
self._call_context = call_context
self._task_id_generator = (
task_id_generator if task_id_generator else UUIDGenerator()
)
self._context_id_generator = (
context_id_generator if context_id_generator else UUIDGenerator()
)
# If the task id and context id were provided, make sure they
# match the request. Otherwise, create them
if self._params:
Expand Down Expand Up @@ -163,7 +176,9 @@ def _check_or_generate_task_id(self) -> None:
return

if not self._task_id and not self._params.message.task_id:
self._params.message.task_id = str(uuid.uuid4())
self._params.message.task_id = self._task_id_generator.generate(
IDGeneratorContext(context_id=self._context_id)
)
if self._params.message.task_id:
self._task_id = self._params.message.task_id

Expand All @@ -173,6 +188,10 @@ def _check_or_generate_context_id(self) -> None:
return

if not self._context_id and not self._params.message.context_id:
self._params.message.context_id = str(uuid.uuid4())
self._params.message.context_id = (
self._context_id_generator.generate(
IDGeneratorContext(task_id=self._task_id)
)
)
if self._params.message.context_id:
self._context_id = self._params.message.context_id
28 changes: 28 additions & 0 deletions src/a2a/server/id_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import uuid

from abc import ABC, abstractmethod

from pydantic import BaseModel


class IDGeneratorContext(BaseModel):
"""Context for providing additional information to ID generators."""

task_id: str | None = None
context_id: str | None = None


class IDGenerator(ABC):
"""Interface for generating unique identifiers."""

@abstractmethod
def generate(self, context: IDGeneratorContext) -> str:
pass


class UUIDGenerator(IDGenerator):
"""UUID implementation of the IDGenerator interface."""

def generate(self, context: IDGeneratorContext) -> str:
"""Generates a random UUID, ignoring the context."""
return str(uuid.uuid4())
35 changes: 31 additions & 4 deletions src/a2a/server/tasks/task_updater.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import uuid

from datetime import datetime, timezone
from typing import Any

from a2a.server.events import EventQueue
from a2a.server.id_generator import (
IDGenerator,
IDGeneratorContext,
UUIDGenerator,
)
from a2a.types import (
Artifact,
Message,
Expand All @@ -23,13 +27,22 @@ class TaskUpdater:
Simplifies the process of creating and enqueueing standard task events.
"""

def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
def __init__(
self,
event_queue: EventQueue,
task_id: str,
context_id: str,
artifact_id_generator: IDGenerator | None = None,
message_id_generator: IDGenerator | None = None,
):
"""Initializes the TaskUpdater.

Args:
event_queue: The `EventQueue` associated with the task.
task_id: The ID of the task.
context_id: The context ID of the task.
artifact_id_generator: ID generator for new artifact IDs. Defaults to UUID generator.
message_id_generator: ID generator for new message IDs. Defaults to UUID generator.
"""
self.event_queue = event_queue
self.task_id = task_id
Expand All @@ -42,6 +55,12 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
TaskState.failed,
TaskState.rejected,
}
self._artifact_id_generator = (
artifact_id_generator if artifact_id_generator else UUIDGenerator()
)
self._message_id_generator = (
message_id_generator if message_id_generator else UUIDGenerator()
)

async def update_status(
self,
Expand Down Expand Up @@ -110,7 +129,11 @@ async def add_artifact( # noqa: PLR0913
extensions: Optional list of extensions for the artifact.
"""
if not artifact_id:
artifact_id = str(uuid.uuid4())
artifact_id = self._artifact_id_generator.generate(
IDGeneratorContext(
task_id=self.task_id, context_id=self.context_id
)
)

await self.event_queue.enqueue_event(
TaskArtifactUpdateEvent(
Expand Down Expand Up @@ -205,7 +228,11 @@ def new_agent_message(
role=Role.agent,
task_id=self.task_id,
context_id=self.context_id,
message_id=str(uuid.uuid4()),
message_id=self._message_id_generator.generate(
IDGeneratorContext(
task_id=self.task_id, context_id=self.context_id
)
),
metadata=metadata,
parts=parts,
)
29 changes: 29 additions & 0 deletions tests/server/agent_execution/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from a2a.server.agent_execution import RequestContext
from a2a.server.context import ServerCallContext
from a2a.server.id_generator import IDGenerator
from a2a.types import (
Message,
MessageSendParams,
Expand Down Expand Up @@ -149,6 +150,20 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
assert context.task_id == existing_id
assert mock_params.message.task_id == existing_id

def test_check_or_generate_task_id_with_custom_id_generator(
self, mock_params
):
"""Test _check_or_generate_task_id uses custom ID generator when provided."""
id_generator = Mock(spec=IDGenerator)
id_generator.generate.return_value = 'custom-task-id'

context = RequestContext(
request=mock_params, task_id_generator=id_generator
)
# The method is called during initialization

assert context.task_id == 'custom-task-id'

def test_check_or_generate_context_id_no_params(self):
"""Test _check_or_generate_context_id with no params does nothing."""
context = RequestContext()
Expand All @@ -168,6 +183,20 @@ def test_check_or_generate_context_id_with_existing_context_id(
assert context.context_id == existing_id
assert mock_params.message.context_id == existing_id

def test_check_or_generate_context_id_with_custom_id_generator(
self, mock_params
):
"""Test _check_or_generate_context_id uses custom ID generator when provided."""
id_generator = Mock(spec=IDGenerator)
id_generator.generate.return_value = 'custom-context-id'

context = RequestContext(
request=mock_params, context_id_generator=id_generator
)
# The method is called during initialization

assert context.context_id == 'custom-context-id'

def test_init_raises_error_on_task_id_mismatch(
self, mock_params, mock_task
):
Expand Down
39 changes: 38 additions & 1 deletion tests/server/tasks/test_task_updater.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import uuid

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch

import pytest

from a2a.server.events import EventQueue
from a2a.server.id_generator import IDGenerator
from a2a.server.tasks import TaskUpdater
from a2a.types import (
Message,
Expand Down Expand Up @@ -151,6 +152,26 @@ async def test_add_artifact_generates_id(
assert event.last_chunk is None


@pytest.mark.asyncio
async def test_add_artifact_generates_custom_id(event_queue, sample_parts):
"""Test add_artifact uses a custom ID generator when provided."""
artifact_id_generator = Mock(spec=IDGenerator)
artifact_id_generator.generate.return_value = 'custom-artifact-id'
task_updater = TaskUpdater(
event_queue=event_queue,
task_id='test-task-id',
context_id='test-context-id',
artifact_id_generator=artifact_id_generator,
)

await task_updater.add_artifact(parts=sample_parts, artifact_id=None)

event_queue.enqueue_event.assert_called_once()
event = event_queue.enqueue_event.call_args[0][0]
assert isinstance(event, TaskArtifactUpdateEvent)
assert event.artifact.artifact_id == 'custom-artifact-id'


@pytest.mark.asyncio
@pytest.mark.parametrize(
'append_val, last_chunk_val',
Expand Down Expand Up @@ -304,6 +325,22 @@ def test_new_agent_message_with_metadata(task_updater, sample_parts):
assert message.metadata == metadata


def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts):
"""Test creating a new agent message with a custom message ID generator."""
message_id_generator = Mock(spec=IDGenerator)
message_id_generator.generate.return_value = 'custom-message-id'
task_updater = TaskUpdater(
event_queue=event_queue,
task_id='test-task-id',
context_id='test-context-id',
message_id_generator=message_id_generator,
)

message = task_updater.new_agent_message(parts=sample_parts)

assert message.message_id == 'custom-message-id'


@pytest.mark.asyncio
async def test_failed_without_message(task_updater, event_queue):
"""Test marking a task as failed without a message."""
Expand Down
Loading