diff --git a/.gitignore b/.gitignore index 4da52568..6252577e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ __pycache__ .venv coverage.xml .nox +spec.json \ No newline at end of file diff --git a/development.md b/development.md new file mode 100644 index 00000000..0d9ef29c --- /dev/null +++ b/development.md @@ -0,0 +1,9 @@ +# Development + +## Type generation from spec + + + +```bash +uv run datamodel-codegen --input ./spec.json --input-file-type jsonschema --output ./src/a2a/types.py --target-python-version 3.10 --output-model-type pydantic_v2.BaseModel --disable-timestamp --use-schema-description --use-union-operator --use-field-description --use-default --use-default-kwarg --use-one-literal-as-default --class-name A2A --use-standard-collections +``` diff --git a/examples/helloworld/test_client.py b/examples/helloworld/test_client.py index 0ade8d2a..561784ad 100644 --- a/examples/helloworld/test_client.py +++ b/examples/helloworld/test_client.py @@ -18,7 +18,7 @@ async def main() -> None: 'message': { 'role': 'user', 'parts': [ - {'type': 'text', 'text': 'how much is 10 USD in INR?'} + {'kind': 'text', 'text': 'how much is 10 USD in INR?'} ], 'messageId': uuid4().hex, }, diff --git a/examples/langgraph/test_client.py b/examples/langgraph/test_client.py index 1eb2664a..3b1c1d03 100644 --- a/examples/langgraph/test_client.py +++ b/examples/langgraph/test_client.py @@ -26,7 +26,7 @@ def create_send_message_payload( payload: dict[str, Any] = { 'message': { 'role': 'user', - 'parts': [{'type': 'text', 'text': text}], + 'parts': [{'kind': 'text', 'text': text}], 'messageId': uuid4().hex, }, } diff --git a/src/a2a/types.py b/src/a2a/types.py index e7dc4c91..3587d441 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -133,13 +133,13 @@ class DataPart(BaseModel): """ Structured data content """ - metadata: dict[str, Any] | None = None + kind: Literal['data'] = 'data' """ - Optional metadata associated with the part. + Part type - data for DataParts """ - type: Literal['data'] = 'data' + metadata: dict[str, Any] | None = None """ - Part type - data for DataParts + Optional metadata associated with the part. """ @@ -577,11 +577,11 @@ class TaskState(Enum): submitted = 'submitted' working = 'working' input_required = 'input-required' - auth_required = 'auth-required' completed = 'completed' canceled = 'canceled' failed = 'failed' rejected = 'rejected' + auth_required = 'auth-required' unknown = 'unknown' @@ -590,6 +590,10 @@ class TextPart(BaseModel): Represents a text segment within parts. """ + kind: Literal['text'] = 'text' + """ + Part type - text for TextParts + """ metadata: dict[str, Any] | None = None """ Optional metadata associated with the part. @@ -598,10 +602,6 @@ class TextPart(BaseModel): """ Text content """ - type: Literal['text'] = 'text' - """ - Part type - text for TextParts - """ class UnsupportedOperationError(BaseModel): @@ -745,13 +745,13 @@ class FilePart(BaseModel): """ File content either as url or bytes """ - metadata: dict[str, Any] | None = None + kind: Literal['file'] = 'file' """ - Optional metadata associated with the part. + Part type - file for FileParts """ - type: Literal['file'] = 'file' + metadata: dict[str, Any] | None = None """ - Part type - file for FileParts + Optional metadata associated with the part. """ @@ -933,7 +933,7 @@ class SetTaskPushNotificationConfigSuccessResponse(BaseModel): class Artifact(BaseModel): """ - Represents an artifact generated for a task. + Represents an artifact generated for a task task. """ artifactId: str @@ -959,9 +959,7 @@ class Artifact(BaseModel): class GetTaskPushNotificationConfigResponse( - RootModel[ - JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse - ] + RootModel[JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse] ): root: JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse """ @@ -978,9 +976,9 @@ class Message(BaseModel): """ the context the message is associated with """ - final: bool | None = None + kind: Literal['message'] = 'message' """ - indicates the end of the event stream + event type """ messageId: str """ @@ -1002,10 +1000,6 @@ class Message(BaseModel): """ identifier of task the message is related to """ - type: Literal['message'] = 'message' - """ - event type - """ class MessageSendParams(BaseModel): @@ -1076,9 +1070,7 @@ class SendStreamingMessageRequest(BaseModel): class SetTaskPushNotificationConfigResponse( - RootModel[ - JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse - ] + RootModel[JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse] ): root: JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse """ @@ -1103,6 +1095,10 @@ class TaskArtifactUpdateEvent(BaseModel): """ the context the task is associated with """ + kind: Literal['artifact-update'] = 'artifact-update' + """ + event type + """ lastChunk: bool | None = None """ Indicates if this is the last chunk of the artifact @@ -1115,10 +1111,6 @@ class TaskArtifactUpdateEvent(BaseModel): """ Task id """ - type: Literal['artifact-update'] = 'artifact-update' - """ - event type - """ class TaskStatus(BaseModel): @@ -1150,6 +1142,10 @@ class TaskStatusUpdateEvent(BaseModel): """ indicates the end of the event stream """ + kind: Literal['status-update'] = 'status-update' + """ + event type + """ metadata: dict[str, Any] | None = None """ extension metadata @@ -1162,10 +1158,6 @@ class TaskStatusUpdateEvent(BaseModel): """ Task id """ - type: Literal['status-update'] = 'status-update' - """ - event type - """ class A2ARequest( @@ -1207,6 +1199,10 @@ class Task(BaseModel): """ unique identifier for the task """ + kind: Literal['task'] = 'task' + """ + event type + """ metadata: dict[str, Any] | None = None """ extension metadata @@ -1215,10 +1211,6 @@ class Task(BaseModel): """ current status of the task """ - type: Literal['task'] = 'task' - """ - event type - """ class CancelTaskSuccessResponse(BaseModel): @@ -1301,9 +1293,7 @@ class SendStreamingMessageSuccessResponse(BaseModel): """ -class CancelTaskResponse( - RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse] -): +class CancelTaskResponse(RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse]): root: JSONRPCErrorResponse | CancelTaskSuccessResponse """ JSON-RPC response for the 'tasks/cancel' method. @@ -1342,9 +1332,7 @@ class JSONRPCResponse( """ -class SendMessageResponse( - RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse] -): +class SendMessageResponse(RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse]): root: JSONRPCErrorResponse | SendMessageSuccessResponse """ JSON-RPC response model for the 'message/send' method. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index aedc8f10..efb6ca12 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -63,14 +63,14 @@ 'id': 'task-abc', 'contextId': 'session-xyz', 'status': {'state': 'working'}, - 'type': 'task', + 'kind': 'task', } MINIMAL_CANCELLED_TASK: dict[str, Any] = { 'id': 'task-abc', 'contextId': 'session-xyz', 'status': {'state': 'canceled'}, - 'type': 'task', + 'kind': 'task', } diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index 56db30a1..08111a2b 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -28,7 +28,7 @@ 'id': '123', 'contextId': 'session-xyz', 'status': {'state': 'submitted'}, - 'type': 'task', + 'kind': 'task', } MESSAGE_PAYLOAD: dict[str, Any] = { diff --git a/tests/server/events/test_event_queue.py b/tests/server/events/test_event_queue.py index 5b440fa1..8a9c163e 100644 --- a/tests/server/events/test_event_queue.py +++ b/tests/server/events/test_event_queue.py @@ -21,7 +21,7 @@ 'id': '123', 'contextId': 'session-xyz', 'status': {'state': 'submitted'}, - 'type': 'task', + 'kind': 'task', } MESSAGE_PAYLOAD: dict[str, Any] = { 'role': 'agent', diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index d5b94605..cf3e66c8 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -62,7 +62,7 @@ 'id': 'task_123', 'contextId': 'session-xyz', 'status': {'state': 'submitted'}, - 'type': 'task', + 'kind': 'task', } MESSAGE_PAYLOAD: dict[str, Any] = { 'role': 'agent', diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index 9eb8eae9..f5d9df1d 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -10,7 +10,7 @@ 'id': 'task-abc', 'contextId': 'session-xyz', 'status': {'state': 'submitted'}, - 'type': 'task', + 'kind': 'task', } diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index aae189fd..56205fab 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -22,7 +22,7 @@ 'id': 'task-abc', 'contextId': 'session-xyz', 'status': {'state': 'submitted'}, - 'type': 'task', + 'kind': 'task', } @@ -205,7 +205,7 @@ async def test_save_task_event_new_task_no_task_id( 'id': 'new-task-id', 'contextId': 'some-context', 'status': {'state': 'working'}, - 'type': 'task', + 'kind': 'task', } task = Task(**task_data) await task_manager_without_id.save_task_event(task) diff --git a/tests/test_types.py b/tests/test_types.py index 5a01eea0..ef658a19 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -106,22 +106,22 @@ 'version': '1.0', } -TEXT_PART_DATA: dict[str, Any] = {'type': 'text', 'text': 'Hello'} +TEXT_PART_DATA: dict[str, Any] = {'kind': 'text', 'text': 'Hello'} FILE_URI_PART_DATA: dict[str, Any] = { - 'type': 'file', + 'kind': 'file', 'file': {'uri': 'file:///path/to/file.txt', 'mimeType': 'text/plain'}, } FILE_BYTES_PART_DATA: dict[str, Any] = { - 'type': 'file', + 'kind': 'file', 'file': {'bytes': 'aGVsbG8=', 'name': 'hello.txt'}, # base64 for "hello" } -DATA_PART_DATA: dict[str, Any] = {'type': 'data', 'data': {'key': 'value'}} +DATA_PART_DATA: dict[str, Any] = {'kind': 'data', 'data': {'key': 'value'}} MINIMAL_MESSAGE_USER: dict[str, Any] = { 'role': 'user', 'parts': [TEXT_PART_DATA], 'messageId': 'msg-123', - 'type': 'message', + 'kind': 'message', } AGENT_MESSAGE_WITH_FILE: dict[str, Any] = { @@ -142,7 +142,7 @@ 'id': 'task-abc', 'contextId': 'session-xyz', 'status': MINIMAL_TASK_STATUS, - 'type': 'task', + 'kind': 'task', } FULL_TASK: dict[str, Any] = { 'id': 'task-abc', @@ -157,7 +157,7 @@ } ], 'metadata': {'priority': 'high'}, - 'type': 'task', + 'kind': 'task', } MINIMAL_TASK_ID_PARAMS: dict[str, Any] = {'id': 'task-123'} @@ -269,7 +269,7 @@ def test_agent_card_invalid(): def test_text_part(): part = TextPart(**TEXT_PART_DATA) - assert part.type == 'text' + assert part.kind == 'text' assert part.text == 'Hello' assert part.metadata is None @@ -277,7 +277,7 @@ def test_text_part(): TextPart(type='text') # Missing text # type: ignore with pytest.raises(ValidationError): TextPart( - type='file', # type: ignore + kind='file', # type: ignore text='hello', ) # Wrong type literal @@ -287,7 +287,7 @@ def test_file_part_variants(): file_uri = FileWithUri( uri='file:///path/to/file.txt', mimeType='text/plain' ) - part_uri = FilePart(type='file', file=file_uri) + part_uri = FilePart(kind='file', file=file_uri) assert isinstance(part_uri.file, FileWithUri) assert part_uri.file.uri == 'file:///path/to/file.txt' assert part_uri.file.mimeType == 'text/plain' @@ -295,7 +295,7 @@ def test_file_part_variants(): # Bytes variant file_bytes = FileWithBytes(bytes='aGVsbG8=', name='hello.txt') - part_bytes = FilePart(type='file', file=file_bytes) + part_bytes = FilePart(kind='file', file=file_bytes) assert isinstance(part_bytes.file, FileWithBytes) assert part_bytes.file.bytes == 'aGVsbG8=' assert part_bytes.file.name == 'hello.txt' @@ -312,14 +312,14 @@ def test_file_part_variants(): # Invalid - wrong type literal with pytest.raises(ValidationError): - FilePart(type='text', file=file_uri) # type: ignore + FilePart(kind='text', file=file_uri) # type: ignore FilePart(**FILE_URI_PART_DATA, extra='extra') # type: ignore def test_data_part(): part = DataPart(**DATA_PART_DATA) - assert part.type == 'data' + assert part.kind == 'data' assert part.data == {'key': 'value'} with pytest.raises(ValidationError): @@ -656,7 +656,7 @@ def test_send_message_streaming_status_update_response() -> None: 'taskId': '1', 'contextId': '2', 'final': False, - 'type': 'status-update', + 'kind': 'status-update', } event_data: dict[str, Any] = { @@ -716,7 +716,7 @@ def test_send_message_streaming_artifact_update_response() -> None: 'contextId': '2', 'append': False, 'lastChunk': True, - 'type': 'artifact-update', + 'kind': 'artifact-update', } event_data: dict[str, Any] = { 'jsonrpc': '2.0',