diff --git a/README.md b/README.md index e9ef85718..060c3e790 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ As AI agents become more prevalent, their ability to interoperate is crucial for - [Marvin](/samples/python/agents/marvin/README.md) - [Semantic Kernel](/samples/python/agents/semantickernel/README.md) - [AG2 + MCP](/samples/python/agents/ag2/README.md) + - [Restate](/samples/python/agents/restate/README.md) ## Contributing diff --git a/samples/python/agents/restate/README.md b/samples/python/agents/restate/README.md new file mode 100644 index 000000000..3a43c6575 --- /dev/null +++ b/samples/python/agents/restate/README.md @@ -0,0 +1,83 @@ +## Resilient Agents with Restate + +This sample uses [Restate](https://ai.restate.dev/) and the Agent Development Kit (ADK) to create a resilient "Expense Reimbursement" agent that is hosted as an A2A server. + +Restate lets you build resilient applications easily. It provides a distributed durable version of your everyday building blocks. + +In this example, Restate acts as a scalable, resilient task orchestrator that speaks the A2A protocol and gives you: +- 🔁 **Automatic retries** - Handles LLM API downtime, timeouts, and infrastructure failures +- 🔄 **Smart recovery** - Preserves progress across failures without duplicating work +- ⏱️ **Persistent task handles** - Tracks progress across failures, time, and processes +- 🎮 **Task control** - Cancel tasks, query status, re-subscribe to ongoing tasks +- 🧠 **Idempotent submission** - Automatic deduplication based on task ID +- 🤖 **Agentic workflows** - Build resilient agents with human-in-the-loop and parallel tool execution +- 💾 **Durable state** - Maintain consistent agent state across infrastructure events +- 👀 **Full observability** - Line-by-line execution tracking with built-in audit trail +- ☁️️ **Easy to self-host** - or connect to Restate Cloud + + + +This agent takes text requests from the client and, if any details are missing, returns a webform for the client (or its user) to fill out. +After the client fills out the form, the agent will complete the task. + +## Prerequisites + +- Python 3.12 or higher +- [UV](https://docs.astral.sh/uv/) +- Access to an LLM and API Key + + +## Running the Sample + +1. Navigate to the samples directory: + ```bash + cd samples/python/agents/restate + ``` +2. Create an environment file with your API key: + + ```bash + echo "GOOGLE_API_KEY=your_api_key_here" > .env + ``` + +4. Run the A2A server and agent: + ```bash + uv run . + ``` + +6. Start the Restate Server with Docker ([for other options check the docs](https://docs.restate.dev/develop/local_dev#running-restate-server--cli-locally)). + + ```shell + docker run -p 8080:8080 -p 9070:9070 \ + --add-host=host.docker.internal:host-gateway \ + -e 'RESTATE_WORKER__INVOKER__INACTIVITY_TIMEOUT=5min' \ + -e 'RESTATE_WORKER__INVOKER__ABORT_TIMEOUT=5min' \ + docker.restate.dev/restatedev/restate:latest + ``` + + Let Restate know where the A2A server is running: + ```shell + docker run -it --network=host docker.restate.dev/restatedev/restate-cli:latest \ + deployments register http://host.docker.internal:9080/restate/v1 + ``` + +5. In a separate terminal, run the A2A client: + ``` + # Connect to the agent (specify the agent URL with correct port) + cd samples/python/hosts/cli + uv run . --agent http://localhost:9080 + + # If you changed the port when starting the agent, use that port instead + # uv run . --agent http://localhost:YOUR_PORT + ``` + +6. Send requests with the A2A client like: `Reimburse my flight of 700 USD` + +Open the Restate UI ([http://localhost:9070](http://localhost:9070)) to see the task execution log and the task state. + +Example of Restate journal view +Example of Restate state view + +# Learn more +- [Restate Website](https://restate.dev/) +- [Restate Documentation](https://docs.restate.dev/) +- [Restate GitHub repo](https://github.com/restatedev/restate) \ No newline at end of file diff --git a/samples/python/agents/restate/__main__.py b/samples/python/agents/restate/__main__.py new file mode 100644 index 000000000..72f107b33 --- /dev/null +++ b/samples/python/agents/restate/__main__.py @@ -0,0 +1,75 @@ +"""A an example of serving a resilient agent using restate.dev""" + +import os + +import restate + +from agent import ReimbursementAgent +from common.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + MissingAPIKeyError, +) +from dotenv import load_dotenv +from fastapi import FastAPI +from middleware import AgentMiddleware + + +load_dotenv() + +RESTATE_HOST = os.getenv('RESTATE_HOST', 'http://localhost:8080') + +AGENT_CARD = AgentCard( + name='ReimbursementAgent', + description='This agent handles the reimbursement process for the employees given the amount and purpose of the reimbursement.', + url=RESTATE_HOST, + version='1.0.0', + defaultInputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES, + defaultOutputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES, + capabilities=AgentCapabilities(streaming=False), + skills=[ + AgentSkill( + id='process_reimbursement', + name='Process Reimbursement Tool', + description='Helps with the reimbursement process for users given the amount and purpose of the reimbursement.', + tags=['reimbursement'], + examples=[ + 'Can you reimburse me $20 for my lunch with the clients?' + ], + ) + ], +) + +REIMBURSEMENT_AGENT = AgentMiddleware(AGENT_CARD, ReimbursementAgent()) + +app = FastAPI() + + +@app.get('/.well-known/agent.json') +async def agent_json(): + """Serve the agent card""" + return REIMBURSEMENT_AGENT.agent_card_json + + +app.mount('/restate/v1', restate.app(REIMBURSEMENT_AGENT)) + + +def main(): + """Serve the agent at a specified port using hypercorn.""" + import asyncio + + import hypercorn + import hypercorn.asyncio + + if not os.getenv('GOOGLE_API_KEY'): + raise MissingAPIKeyError('GOOGLE_API_KEY environment variable not set.') + + port = os.getenv('AGENT_PORT', '9080') + conf = hypercorn.Config() + conf.bind = [f'0.0.0.0:{port}'] + asyncio.run(hypercorn.asyncio.serve(app, conf)) + + +if __name__ == '__main__': + main() diff --git a/samples/python/agents/restate/agent.py b/samples/python/agents/restate/agent.py new file mode 100644 index 000000000..2b20ad189 --- /dev/null +++ b/samples/python/agents/restate/agent.py @@ -0,0 +1,228 @@ +"""An agent that handles reimbursement requests. Pretty much a copy of the +reimbursement agent from this repo, just made the tools a bit more interesting. +""" + +import json +import logging +import random + +from typing import Any, Optional + +from agents.restate.middleware import AgentInvokeResult +from common.types import TextPart +from google.adk.agents.llm_agent import LlmAgent +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +logger = logging.getLogger(__name__) + + +# Local cache of created request_ids for demo purposes. +request_ids = set() + + +def create_request_form( + date: Optional[str] = None, + amount: Optional[str] = None, + purpose: Optional[str] = None, +) -> dict[str, Any]: + """Create a request form for the employee to fill out. + + Args: + date (str): The date of the request. Can be an empty string. + amount (str): The requested amount. Can be an empty string. + purpose (str): The purpose of the request. Can be an empty string. + + Returns: + dict[str, Any]: A dictionary containing the request form data. + """ + logger.info('Creating reimbursement request') + request_id = 'request_id_' + str(random.randint(1000000, 9999999)) + request_ids.add(request_id) + reimbursement = { + 'request_id': request_id, + 'date': '' if not date else date, + 'amount': '' if not amount else amount, + 'purpose': ( + '' + if not purpose + else purpose + ), + } + logger.info('Reimbursement request created: %s', json.dumps(reimbursement)) + + return reimbursement + + +def return_form( + form_request: dict[str, Any], + tool_context: ToolContext, + instructions: Optional[str] = None, +) -> dict[str, Any]: + """Returns a structured json object indicating a form to complete. + + Args: + form_request (dict[str, Any]): The request form data. + tool_context (ToolContext): The context in which the tool operates. + instructions (str): Instructions for processing the form. Can be an empty string. + + Returns: + dict[str, Any]: A JSON dictionary for the form response. + """ + logger.info('Creating return form') + if isinstance(form_request, str): + form_request = json.loads(form_request) + + form_dict = { + 'type': 'form', + 'form': { + 'type': 'object', + 'properties': { + 'date': { + 'type': 'string', + 'format': 'date', + 'description': 'Date of expense', + 'title': 'Date', + }, + 'amount': { + 'type': 'string', + 'format': 'number', + 'description': 'Amount of expense', + 'title': 'Amount', + }, + 'purpose': { + 'type': 'string', + 'description': 'Purpose of expense', + 'title': 'Purpose', + }, + 'request_id': { + 'type': 'string', + 'description': 'Request id', + 'title': 'Request ID', + }, + }, + 'required': list(form_request.keys()), + }, + 'form_data': form_request, + 'instructions': instructions, + } + logger.info('Return form created: %s', json.dumps(form_dict)) + return json.dumps(form_dict) + + +async def reimburse(request_id: str) -> dict[str, Any]: + """Reimburse the amount of money to the employee for a given request_id.""" + logger.info('Starting reimbursement: %s', request_id) + if request_id not in request_ids: + return { + 'request_id': request_id, + 'status': 'Error: Invalid request_id.', + } + logger.info('Reimbursement approved: %s', request_id) + return {'request_id': request_id, 'status': 'approved'} + + +class ReimbursementAgent: + """An agent that handles reimbursement requests.""" + + SUPPORTED_CONTENT_TYPES = ['text', 'text/plain'] + + def __init__(self): + self._agent = self._build_agent() + self._user_id = 'remote_agent' + self._runner = Runner( + app_name=self._agent.name, + agent=self._agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + + async def invoke(self, query, session_id) -> AgentInvokeResult: + logger.info('Invoking LLM') + session = self._runner.session_service.get_session( + app_name=self._agent.name, + user_id=self._user_id, + session_id=session_id, + ) + content = types.Content( + role='user', parts=[types.Part.from_text(text=query)] + ) + if session is None: + self._runner.session_service.create_session( + app_name=self._agent.name, + user_id=self._user_id, + state={}, + session_id=session_id, + ) + + events = [] + async for event in self._runner.run_async( + user_id=self._user_id, + session_id=session_id, + new_message=content, + ): + events.append(event) + + logger.info('LLM response: %s', events) + if not events or not events[-1].content or not events[-1].content.parts: + return AgentInvokeResult( + parts=[TextPart(text='')], + require_user_input=False, + is_task_complete=True, + ) + return AgentInvokeResult( + parts=[ + TextPart( + text='\n'.join( + [p.text for p in events[-1].content.parts if p.text] + ) + ) + ], + require_user_input=False, + is_task_complete=True, + ) + + def _build_agent(self) -> LlmAgent: + """Builds the LLM agent for the reimbursement agent.""" + return LlmAgent( + model='gemini-2.0-flash-001', + name='reimbursement_agent', + description=( + 'This agent handles the reimbursement process for the employees' + ' given the amount and purpose of the reimbursement.' + ), + instruction=""" + You are an agent who handle the reimbursement process for employees. + + When you receive an reimbursement request, you should first create a new request form using create_request_form(). Only provide default values if they are provided by the user, otherwise use an empty string as the default value. + 1. 'Date': the date of the transaction. + 2. 'Amount': the dollar amount of the transaction. + 3. 'Business Justification/Purpose': the reason for the reimbursement. + + Once you created the form, you should return the result of calling return_form with the form data from the create_request_form call. + Clearly let the user know which fields are required and missing. + + Once you received the filled-out form back from the user, you should then check the form contains all required information: + 1. 'Date': the date of the transaction. + 2. 'Amount': the value of the amount of the reimbursement being requested. + 3. 'Business Justification/Purpose': the item/object/artifact of the reimbursement. + + If you don't have all of the information, you should reject the request directly by calling the request_form method, providing the missing fields. + + + For valid reimbursement requests, you can then use reimburse() to reimburse the employee. + * In your response, you should include the request_id and the status of the reimbursement request. + + """, + tools=[ + create_request_form, + reimburse, + return_form, + ], + ) diff --git a/samples/python/agents/restate/middleware.py b/samples/python/agents/restate/middleware.py new file mode 100644 index 000000000..9a70f69ab --- /dev/null +++ b/samples/python/agents/restate/middleware.py @@ -0,0 +1,457 @@ +# pylint: disable=C0116 +import logging +import uuid + +from collections.abc import AsyncIterable, Iterable +from datetime import datetime + +import restate + +from common.types import ( + A2ARequest, + AgentCard, + Artifact, + CancelTaskRequest, + CancelTaskResponse, + GetTaskPushNotificationRequest, + GetTaskPushNotificationResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Message, + Part, + SendTaskRequest, + SendTaskResponse, + SendTaskStreamingRequest, + SendTaskStreamingResponse, + SetTaskPushNotificationRequest, + SetTaskPushNotificationResponse, + Task, + TaskIdParams, + TaskNotFoundError, + TaskQueryParams, + TaskResubscriptionRequest, + TaskSendParams, + TaskState, + TaskStatus, + TextPart, +) +from pydantic import BaseModel +from restate.serde import PydanticJsonSerde + + +logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] [%(process)d] [%(levelname)s] - %(message)s', +) +logger = logging.getLogger(__name__) + + +# MODELS + + +class AgentInvokeResult(BaseModel): + """Result of the agent invocation.""" + + parts: list[Part] + require_user_input: bool + is_task_complete: bool + + +# K/V stored in Restate +TASK = 'task' +INVOCATION_ID = 'invocation-id' + + +class AgentMiddleware(Iterable[restate.Service | restate.VirtualObject]): + """Middleware for the agent to handle task processing and state management.""" + + def __init__(self, agent_card: AgentCard, agent): + self.agent_card = agent_card.model_copy() + self.agent = agent + self.a2a_server_name = f'{self.agent_card.name}A2AServer' + self.task_object_name = f'{self.agent_card.name}TaskObject' + + # replace the base url with the exact url of the process_request handler. + restate_base_url = self.agent_card.url + process_request_url = ( + f'{restate_base_url}/{self.a2a_server_name}/process_request' + ) + self.agent_card.url = process_request_url + + self.restate_services = [] + _build_services(self) + + def __iter__(self): + """Returns the services that define the agent's a2a server and task object.""" + return iter(self.restate_services) + + @property + def agent_card_json(self): + """Return the agent card""" + return self.agent_card.model_dump() + + @property + def services(self) -> Iterable[restate.Service | restate.VirtualObject]: + """Return the services that define the agent's a2a server and task object""" + return self.restate_services + + +def _build_services(middleware: AgentMiddleware): + """Creates an A2A server for reimbursement processing with customizable name and description.""" + a2a_service = restate.Service( + middleware.a2a_server_name, + description=middleware.agent_card.description, + metadata={ + 'agent': middleware.agent_card.name, + 'version': middleware.agent_card.version, + }, + ) + middleware.restate_services.append(a2a_service) + + task_object = restate.VirtualObject(middleware.task_object_name) + middleware.restate_services.append(task_object) + + agent = middleware.agent + + class TaskObject: + """TaskObject is a virtual object that handles task processing and state management.""" + + @staticmethod + @task_object.handler(kind='shared') + async def get_invocation_id( + ctx: restate.ObjectSharedContext, + ) -> str | None: + task_id = ctx.key() + logger.info('Getting invocation id for task %s', task_id) + return await ctx.get(INVOCATION_ID) or None + + @staticmethod + @task_object.handler( + output_serde=PydanticJsonSerde(Task), kind='shared' + ) + async def get_task( + ctx: restate.ObjectSharedContext, + ) -> Task | None: + task_id = ctx.key() + logger.info('Getting task %s', task_id) + return await ctx.get(TASK, type_hint=Task) or None + + @staticmethod + @task_object.handler() + async def cancel_task( + ctx: restate.ObjectContext, request: CancelTaskRequest + ) -> CancelTaskResponse: + cancelled_task = await TaskObject.update_store( + ctx, state=TaskState.CANCELED + ) + return CancelTaskResponse(id=request.id, result=cancelled_task) + + @staticmethod + @task_object.handler() + async def handle_send_task_request( + ctx: restate.ObjectContext, request: SendTaskRequest + ) -> SendTaskResponse: + logger.info( + 'Starting task execution workflow %s for task %s', + request.id, + request.params.id, + ) + + task_send_params: TaskSendParams = request.params + if not task_send_params.sessionId: + session_id = await ctx.run( + 'Generate session id', lambda: str(uuid.uuid4().hex) + ) + task_send_params.sessionId = session_id + + # Store this invocation ID so it can be cancelled by someone else + await TaskObject.set_invocation_id(ctx, ctx.request().id) + + # Persist the request data + await TaskObject.upsert_task(ctx, task_send_params) + + try: + # Forward the request to the agent + result = await ctx.run( + 'Agent invoke', + agent.invoke, + args=( + _get_user_query(task_send_params), + task_send_params.sessionId, + ), + type_hint=AgentInvokeResult, + ) + + if result.require_user_input: + updated_task = await TaskObject.update_store( + ctx, + state=TaskState.INPUT_REQUIRED, + status_message=Message( + role='agent', parts=result.parts + ), + ) + else: + updated_task = await TaskObject.update_store( + ctx, + state=TaskState.COMPLETED, + artifacts=[Artifact(parts=result.parts)], + ) + + ctx.clear(INVOCATION_ID) + return SendTaskResponse(id=request.id, result=updated_task) + except restate.exceptions.TerminalError as e: + if e.status_code == 409 and e.message == 'cancelled': + logger.info('Task %s was cancelled', task_send_params.id) + cancelled_task = await TaskObject.update_store( + ctx, state=TaskState.CANCELED + ) + ctx.clear(INVOCATION_ID) + return SendTaskResponse( + id=request.id, result=cancelled_task + ) + + logger.error( + 'Error while processing task %s: %s', task_send_params.id, e + ) + failed_task = await TaskObject.update_store( + ctx, state=TaskState.FAILED + ) + ctx.clear(INVOCATION_ID) + return SendTaskResponse(id=request.id, result=failed_task) + + @staticmethod + async def update_store( + ctx: restate.ObjectContext, + state: TaskState | None, + status_message: Message | None = None, + artifacts: list[Artifact] | None = None, + ) -> Task: + task_id = ctx.key() + logger.info('Updating status task %s to %s', task_id, state) + + task = await ctx.get(TASK, type_hint=Task) + if task is None: + logger.error('Task %s not found for updating the task', task_id) + raise restate.exceptions.TerminalError( + f'Task {task_id} not found' + ) + + new_task_status = await ctx.run( + 'task status', + lambda task_state=state: TaskStatus( + state=task_state, + timestamp=datetime.now(), + message=status_message, + ), + type_hint=TaskStatus, + ) + prev_status = task.status + if prev_status.message is not None: + task.history.append(prev_status.message) + task.status = new_task_status + + if artifacts is not None: + if task.artifacts is None: + task.artifacts = [] + task.artifacts.extend(artifacts) + + ctx.set(TASK, task) + return task + + @staticmethod + async def set_invocation_id( + ctx: restate.ObjectContext, invocation_id: str + ): + task_id = ctx.key() + logger.info( + 'Adding invocation id %s for task %s', invocation_id, task_id + ) + current_invocation_id = await ctx.get(INVOCATION_ID) + if current_invocation_id is not None: + raise restate.exceptions.TerminalError( + 'There is an ongoing invocation. How did we end up here?' + ) + ctx.set(INVOCATION_ID, invocation_id) + + @staticmethod + async def upsert_task( + ctx: restate.ObjectContext, task_send_params: TaskSendParams + ) -> Task: + task_id = ctx.key() + logger.info('Upserting task %s', task_id) + + task_state = await ctx.get(TASK, type_hint=Task) + if task_state is None: + task_state = await ctx.run( + 'Create task', + lambda run_params=task_send_params: Task( + id=run_params.id, + sessionId=run_params.sessionId, + status=TaskStatus( + state=TaskState.SUBMITTED, timestamp=datetime.now() + ), + history=[run_params.message] + if run_params.message + else [], + ), + type_hint=Task, + ) + else: + task_state.history.append(task_send_params.message) + + ctx.set(TASK, task_state) + return task_state + + class A2aService: + @a2a_service.handler() + @staticmethod + async def process_request( + ctx: restate.Context, req: JSONRPCRequest + ) -> JSONRPCResponse: + methods = { + GetTaskRequest: A2aService.on_get_task, + SendTaskRequest: A2aService.on_send_task_request, + CancelTaskRequest: A2aService.on_cancel_task, + SendTaskStreamingRequest: A2aService.on_send_task_subscribe, + SetTaskPushNotificationRequest: A2aService.on_set_task_push_notification, + GetTaskPushNotificationRequest: A2aService.on_get_task_push_notification, + } + + try: + json_rpc_request = A2ARequest.validate_python(req.model_dump()) + except Exception as e: + logger.error('Error validating request: %s', e) + return JSONRPCResponse( + id=req.id, + error=JSONRPCError( + code=400, message='Invalid request format' + ), + ) + fn = methods.get(type(json_rpc_request), None) + if not fn: + return JSONRPCResponse( + id=req.id, + error=JSONRPCError(code=400, message='Method not found'), + ) + try: + return await fn(ctx, json_rpc_request) + except restate.exceptions.TerminalError as e: + logger.error('Error processing request: %s', e) + return JSONRPCResponse( + id=req.id, + error=JSONRPCError(code=e.status_code, message=e.message), + ) + + @staticmethod + async def on_send_task_request( + ctx: restate.Context, request: SendTaskRequest + ) -> SendTaskResponse: + logger.info('Sending task %s', request.params.id) + task_send_params: TaskSendParams = request.params + + return await ctx.object_call( + TaskObject.handle_send_task_request, + key=task_send_params.id, + arg=request, + idempotency_key=str(request.id), + ) + + @staticmethod + async def on_get_task( + ctx: restate.Context, request: GetTaskRequest + ) -> GetTaskResponse: + logger.info('Getting task %s', request.params.id) + task_query_params: TaskQueryParams = request.params + + task = await ctx.object_call( + TaskObject.get_task, key=task_query_params.id, arg=None + ) + if task is None: + return GetTaskResponse(id=request.id, error=TaskNotFoundError()) + + task_result = task.model_copy() + history_length = task_query_params.historyLength + if history_length is not None and history_length > 0: + task_result.history = task.history[-history_length:] + else: + # Default is no history + task_result.history = [] + return GetTaskResponse(id=request.id, result=task_result) + + @staticmethod + async def on_cancel_task( + ctx: restate.Context, request: CancelTaskRequest + ) -> CancelTaskResponse: + logger.info('Cancelling task %s', request.params.id) + task_id_params: TaskIdParams = request.params + + task = await ctx.object_call( + TaskObject.get_task, key=task_id_params.id, arg=None + ) + if task is None: + return CancelTaskResponse( + id=request.id, error=TaskNotFoundError() + ) + invocation_id = await ctx.object_call( + TaskObject.get_invocation_id, key=task_id_params.id, arg=None + ) + if invocation_id is None: + # Task either doesn't exist or is already completed + return await ctx.object_call( + TaskObject.cancel_task, key=task_id_params.id, arg=request + ) + + # Cancel the invocation + ctx.cancel_invocation(invocation_id) + # Wait for cancellation to complete and for the cancelled task info + canceled_task_info = await ctx.attach_invocation( + invocation_id, type_hint=SendTaskResponse + ) + return CancelTaskResponse( + id=request.id, + result=canceled_task_info.result, + ) + + @staticmethod + async def on_set_task_push_notification( + ctx: restate.Context, request: SetTaskPushNotificationRequest + ) -> SetTaskPushNotificationResponse: + raise restate.exceptions.TerminalError( + f'Not implemented: {request.method}' + ) + + @staticmethod + async def on_get_task_push_notification( + ctx: restate.Context, request: GetTaskPushNotificationRequest + ) -> GetTaskPushNotificationResponse: + raise restate.exceptions.TerminalError( + f'Not implemented: {request.method}' + ) + + @staticmethod + async def on_send_task_subscribe( + ctx: restate.Context, request: SendTaskStreamingRequest + ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: + raise restate.exceptions.TerminalError( + f'Not implemented: {request.method}' + ) + + @staticmethod + async def on_resubscribe_to_task( + ctx: restate.Context, request: TaskResubscriptionRequest + ) -> AsyncIterable[SendTaskResponse] | JSONRPCResponse: + raise restate.exceptions.TerminalError( + status_code=500, message=f'Not implemented: {request.method}' + ) + + return a2a_service, task_object + + +def _get_user_query(task_send_params: TaskSendParams) -> str: + part = task_send_params.message.parts[0] + if not isinstance(part, TextPart): + raise restate.exceptions.TerminalError('Only text parts are supported') + return part.text diff --git a/samples/python/agents/restate/pyproject.toml b/samples/python/agents/restate/pyproject.toml new file mode 100644 index 000000000..e3a33051e --- /dev/null +++ b/samples/python/agents/restate/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "a2a-samples-restate" +version = "0.1.0" +description = "Resilient A2A agent using the Restate SDK" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "a2a-samples", + "restate-sdk[serde]>=0.7.1", + "hypercorn", + "google-adk>=0.0.3", + "google-genai>=1.9.0", + "fastapi" +] + +[tool.uv.sources] +a2a-samples = { workspace = true } diff --git a/samples/python/pyproject.toml b/samples/python/pyproject.toml index b8a33a035..69c1463f7 100644 --- a/samples/python/pyproject.toml +++ b/samples/python/pyproject.toml @@ -30,7 +30,8 @@ members = [ "agents/llama_index_file_chat", "agents/semantickernel", "agents/mindsdb", - "agents/ag2" + "agents/ag2", + "agents/restate" ] [build-system] diff --git a/samples/python/uv.lock b/samples/python/uv.lock index 9e0de9876..8f4e4aba3 100644 --- a/samples/python/uv.lock +++ b/samples/python/uv.lock @@ -19,6 +19,7 @@ members = [ "a2a-samples-image-gen", "a2a-samples-marvin", "a2a-samples-mcp", + "a2a-samples-restate", "a2a-semantic-kernel", ] @@ -233,6 +234,29 @@ requires-dist = [ { name = "google-genai", specifier = ">=1.10.0" }, ] +[[package]] +name = "a2a-samples-restate" +version = "0.1.0" +source = { virtual = "agents/restate" } +dependencies = [ + { name = "a2a-samples" }, + { name = "fastapi" }, + { name = "google-adk" }, + { name = "google-genai" }, + { name = "hypercorn" }, + { name = "restate-sdk", extra = ["serde"] }, +] + +[package.metadata] +requires-dist = [ + { name = "a2a-samples", editable = "." }, + { name = "fastapi" }, + { name = "google-adk", specifier = ">=0.0.3" }, + { name = "google-genai", specifier = ">=1.9.0" }, + { name = "hypercorn" }, + { name = "restate-sdk", extras = ["serde"], specifier = ">=0.7.1" }, +] + [[package]] name = "a2a-semantic-kernel" version = "0.1.0" @@ -1046,6 +1070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375 }, ] +[[package]] +name = "dacite" +version = "1.9.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600 }, +] + [[package]] name = "dataclasses-json" version = "0.6.7" @@ -1958,6 +1991,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 }, ] +[[package]] +name = "hypercorn" +version = "0.17.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, + { name = "h2" }, + { name = "priority" }, + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/3a/df6c27642e0dcb7aff688ca4be982f0fb5d89f2afd3096dc75347c16140f/hypercorn-0.17.3.tar.gz", hash = "sha256:1b37802ee3ac52d2d85270700d565787ab16cf19e1462ccfa9f089ca17574165", size = 44409 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/3b/dfa13a8d96aa24e40ea74a975a9906cfdc2ab2f4e3b498862a57052f04eb/hypercorn-0.17.3-py3-none-any.whl", hash = "sha256:059215dec34537f9d40a69258d323f56344805efb462959e727152b0aa504547", size = 61742 }, +] + [[package]] name = "hyperframe" version = "6.1.0" @@ -3644,6 +3692,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/a8/fc509e514c708f43102542cdcbc2f42dc49f7a159f90f56d072371629731/prance-25.4.8.0-py3-none-any.whl", hash = "sha256:d3c362036d625b12aeee495621cb1555fd50b2af3632af3d825176bfb50e073b", size = 36386 }, ] +[[package]] +name = "priority" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/3c/eb7c35f4dcede96fca1842dac5f4f5d15511aa4b52f3a961219e68ae9204/priority-2.0.0.tar.gz", hash = "sha256:c965d54f1b8d0d0b19479db3924c7c36cf672dbf2aec92d43fbdaf4492ba18c0", size = 24792 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/5f/82c8074f7e84978129347c2c6ec8b6c59f3584ff1a20bc3c940a3e061790/priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa", size = 8946 }, +] + [[package]] name = "prompt-toolkit" version = "3.0.51" @@ -4421,6 +4478,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481 }, ] +[[package]] +name = "restate-sdk" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/98/840dac505dd0ec687b17e66963e8a17b0ca48c7a38cca34b3a49467de591/restate_sdk-0.7.1.tar.gz", hash = "sha256:bdbc6999838135896c3d809d02828251b6af33bac78ab55f908630b89472c203", size = 59163 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/56/23aa19ad0441645543e16743411b80e8e956188d1122d67cb741000761c1/restate_sdk-0.7.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7d215044e4c762aba95b375daa65de2b0cf5f9400ab046c825751f42030b9819", size = 1790030 }, + { url = "https://files.pythonhosted.org/packages/ed/2f/0e467831b66d47f93dae51b4b1fe59cc1c6e4a1906a5fe87b82231c81ccb/restate_sdk-0.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ed5d7129ea4f78f6132933c075eeb1055b603f54ee7a25fafb4a30adc207033", size = 1714262 }, + { url = "https://files.pythonhosted.org/packages/a0/a8/831934f62899a31872cec077f262a13fd9c5c278dd5343f8721c0b34f7fa/restate_sdk-0.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f540b7aa54cb43bb6f6393be3612f8d86685aa7983108aef51cf7a3b853a139e", size = 1973513 }, + { url = "https://files.pythonhosted.org/packages/64/77/b4aab0c1742c996013ef72276f35b077b365d2b35b2b61c7d342594979d6/restate_sdk-0.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d5e8bdc691584801dd8d0076101b6bb9bece8027fe448f4a109a31d318a213ea", size = 1938965 }, + { url = "https://files.pythonhosted.org/packages/43/f7/077a0b0e6f8ee4861df1c62e4cc25ac5a7c0c270b335687a477f96b3c9f9/restate_sdk-0.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f246aa753cc6190cff2aa84a0c1873e8eef983242f0111cf8b625d5415c6a6ef", size = 2126752 }, + { url = "https://files.pythonhosted.org/packages/60/24/e2dbdccd143d859c52d723c7034c0c716ec3790408b460512de34cf77888/restate_sdk-0.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:641cb5d628edfba9ffc7b94f93203b9c89b4988f35b1e70fe6e11a45872e6edd", size = 2140057 }, + { url = "https://files.pythonhosted.org/packages/e1/95/1eadab39a71d90b725a3f111130004ed3e1cac51af59679b3031656fb8d0/restate_sdk-0.7.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:c7b844344b441574e96240e5dcd66c2f63eb007c7177166737ae2b08f4151102", size = 1790275 }, + { url = "https://files.pythonhosted.org/packages/46/1a/efc21da414dce524a12f723221df6da3987ea481675faac95f0020212e26/restate_sdk-0.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a39d89a76e4ed84deeae19ac7b6ba2e78400ff5640408ebd385dbdb84c49a385", size = 1714709 }, + { url = "https://files.pythonhosted.org/packages/df/c2/c39ff95197638f1cfdb6349eb856e2487f6c7a3386019fe42d7f72331af7/restate_sdk-0.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e07b6288f0dd3f0e5cfbda0ec23fa5f46cd9f051736de3dd0488d433879759f5", size = 1973842 }, + { url = "https://files.pythonhosted.org/packages/49/a7/85b2682f7fe2fbe541186a3222f65fc8a2a08efbfa32c9fa51aa6a45c00c/restate_sdk-0.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:188c9a9dc8b0a9bce59009ef20e5af883db6d0e9ec168e0980dea87e0564a9ea", size = 1939231 }, + { url = "https://files.pythonhosted.org/packages/f0/70/8f5b20f24e680781d6f7a8d7ff8f5c0519f2261a774f865b24c7bf6d24b5/restate_sdk-0.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3d44b27449bc6fba8d8ab72efe01a1cde91b21b80f6b2774ba8532b30839d2d2", size = 2126751 }, + { url = "https://files.pythonhosted.org/packages/f8/e5/79d64697f6a77980be2a2b93a3519f916ef567e66944a6b3de6bef1b0b5c/restate_sdk-0.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d67ba4a9d1cc4831312254ea9cf9dad6b54e7cff0b469d5bd20300c6739d56dd", size = 2140418 }, + { url = "https://files.pythonhosted.org/packages/c4/d6/279d05df6c7fe2849b3427efdbcd82400396448381cdf77df6a918fbba79/restate_sdk-0.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:8c5b97ac27ed80828f4388768367f254271c39552e25ec943aaacd5f0a12064e", size = 1939642 }, + { url = "https://files.pythonhosted.org/packages/cb/c2/ab19491e76d3f43976c432b5405fb384a3c2c66f325ec04f5a072a93b4e5/restate_sdk-0.7.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c893e762cdf8ecb994df75ebc3132803748f4ee1387795be12b4e838160268d1", size = 2127237 }, + { url = "https://files.pythonhosted.org/packages/cf/f5/dca03eb140b00d85441f20f71799357c11ce848f38d7fc48925755763bcb/restate_sdk-0.7.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:24eac2aae43ba5b433d75999024ec59e64717fe0b237d7c891d1fe14627d82ca", size = 2139918 }, +] + +[package.optional-dependencies] +serde = [ + { name = "dacite" }, + { name = "pydantic" }, +] + [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -5265,6 +5351,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594 }, ] +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226 }, +] + [[package]] name = "xxhash" version = "3.5.0"