diff --git a/README.md b/README.md
index e9ef85718..060c3e790 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ As AI agents become more prevalent, their ability to interoperate is crucial for
- [Marvin](/samples/python/agents/marvin/README.md)
- [Semantic Kernel](/samples/python/agents/semantickernel/README.md)
- [AG2 + MCP](/samples/python/agents/ag2/README.md)
+ - [Restate](/samples/python/agents/restate/README.md)
## Contributing
diff --git a/samples/python/agents/restate/README.md b/samples/python/agents/restate/README.md
new file mode 100644
index 000000000..3a43c6575
--- /dev/null
+++ b/samples/python/agents/restate/README.md
@@ -0,0 +1,83 @@
+## Resilient Agents with Restate
+
+This sample uses [Restate](https://ai.restate.dev/) and the Agent Development Kit (ADK) to create a resilient "Expense Reimbursement" agent that is hosted as an A2A server.
+
+Restate lets you build resilient applications easily. It provides a distributed durable version of your everyday building blocks.
+
+In this example, Restate acts as a scalable, resilient task orchestrator that speaks the A2A protocol and gives you:
+- 🔁 **Automatic retries** - Handles LLM API downtime, timeouts, and infrastructure failures
+- 🔄 **Smart recovery** - Preserves progress across failures without duplicating work
+- ⏱️ **Persistent task handles** - Tracks progress across failures, time, and processes
+- 🎮 **Task control** - Cancel tasks, query status, re-subscribe to ongoing tasks
+- 🧠 **Idempotent submission** - Automatic deduplication based on task ID
+- 🤖 **Agentic workflows** - Build resilient agents with human-in-the-loop and parallel tool execution
+- 💾 **Durable state** - Maintain consistent agent state across infrastructure events
+- 👀 **Full observability** - Line-by-line execution tracking with built-in audit trail
+- ☁️️ **Easy to self-host** - or connect to Restate Cloud
+
+
+
+This agent takes text requests from the client and, if any details are missing, returns a webform for the client (or its user) to fill out.
+After the client fills out the form, the agent will complete the task.
+
+## Prerequisites
+
+- Python 3.12 or higher
+- [UV](https://docs.astral.sh/uv/)
+- Access to an LLM and API Key
+
+
+## Running the Sample
+
+1. Navigate to the samples directory:
+ ```bash
+ cd samples/python/agents/restate
+ ```
+2. Create an environment file with your API key:
+
+ ```bash
+ echo "GOOGLE_API_KEY=your_api_key_here" > .env
+ ```
+
+4. Run the A2A server and agent:
+ ```bash
+ uv run .
+ ```
+
+6. Start the Restate Server with Docker ([for other options check the docs](https://docs.restate.dev/develop/local_dev#running-restate-server--cli-locally)).
+
+ ```shell
+ docker run -p 8080:8080 -p 9070:9070 \
+ --add-host=host.docker.internal:host-gateway \
+ -e 'RESTATE_WORKER__INVOKER__INACTIVITY_TIMEOUT=5min' \
+ -e 'RESTATE_WORKER__INVOKER__ABORT_TIMEOUT=5min' \
+ docker.restate.dev/restatedev/restate:latest
+ ```
+
+ Let Restate know where the A2A server is running:
+ ```shell
+ docker run -it --network=host docker.restate.dev/restatedev/restate-cli:latest \
+ deployments register http://host.docker.internal:9080/restate/v1
+ ```
+
+5. In a separate terminal, run the A2A client:
+ ```
+ # Connect to the agent (specify the agent URL with correct port)
+ cd samples/python/hosts/cli
+ uv run . --agent http://localhost:9080
+
+ # If you changed the port when starting the agent, use that port instead
+ # uv run . --agent http://localhost:YOUR_PORT
+ ```
+
+6. Send requests with the A2A client like: `Reimburse my flight of 700 USD`
+
+Open the Restate UI ([http://localhost:9070](http://localhost:9070)) to see the task execution log and the task state.
+
+
+
+
+# Learn more
+- [Restate Website](https://restate.dev/)
+- [Restate Documentation](https://docs.restate.dev/)
+- [Restate GitHub repo](https://github.com/restatedev/restate)
\ No newline at end of file
diff --git a/samples/python/agents/restate/__main__.py b/samples/python/agents/restate/__main__.py
new file mode 100644
index 000000000..72f107b33
--- /dev/null
+++ b/samples/python/agents/restate/__main__.py
@@ -0,0 +1,75 @@
+"""A an example of serving a resilient agent using restate.dev"""
+
+import os
+
+import restate
+
+from agent import ReimbursementAgent
+from common.types import (
+ AgentCapabilities,
+ AgentCard,
+ AgentSkill,
+ MissingAPIKeyError,
+)
+from dotenv import load_dotenv
+from fastapi import FastAPI
+from middleware import AgentMiddleware
+
+
+load_dotenv()
+
+RESTATE_HOST = os.getenv('RESTATE_HOST', 'http://localhost:8080')
+
+AGENT_CARD = AgentCard(
+ name='ReimbursementAgent',
+ description='This agent handles the reimbursement process for the employees given the amount and purpose of the reimbursement.',
+ url=RESTATE_HOST,
+ version='1.0.0',
+ defaultInputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,
+ defaultOutputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,
+ capabilities=AgentCapabilities(streaming=False),
+ skills=[
+ AgentSkill(
+ id='process_reimbursement',
+ name='Process Reimbursement Tool',
+ description='Helps with the reimbursement process for users given the amount and purpose of the reimbursement.',
+ tags=['reimbursement'],
+ examples=[
+ 'Can you reimburse me $20 for my lunch with the clients?'
+ ],
+ )
+ ],
+)
+
+REIMBURSEMENT_AGENT = AgentMiddleware(AGENT_CARD, ReimbursementAgent())
+
+app = FastAPI()
+
+
+@app.get('/.well-known/agent.json')
+async def agent_json():
+ """Serve the agent card"""
+ return REIMBURSEMENT_AGENT.agent_card_json
+
+
+app.mount('/restate/v1', restate.app(REIMBURSEMENT_AGENT))
+
+
+def main():
+ """Serve the agent at a specified port using hypercorn."""
+ import asyncio
+
+ import hypercorn
+ import hypercorn.asyncio
+
+ if not os.getenv('GOOGLE_API_KEY'):
+ raise MissingAPIKeyError('GOOGLE_API_KEY environment variable not set.')
+
+ port = os.getenv('AGENT_PORT', '9080')
+ conf = hypercorn.Config()
+ conf.bind = [f'0.0.0.0:{port}']
+ asyncio.run(hypercorn.asyncio.serve(app, conf))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/samples/python/agents/restate/agent.py b/samples/python/agents/restate/agent.py
new file mode 100644
index 000000000..2b20ad189
--- /dev/null
+++ b/samples/python/agents/restate/agent.py
@@ -0,0 +1,228 @@
+"""An agent that handles reimbursement requests. Pretty much a copy of the
+reimbursement agent from this repo, just made the tools a bit more interesting.
+"""
+
+import json
+import logging
+import random
+
+from typing import Any, Optional
+
+from agents.restate.middleware import AgentInvokeResult
+from common.types import TextPart
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.artifacts import InMemoryArtifactService
+from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
+from google.adk.runners import Runner
+from google.adk.sessions import InMemorySessionService
+from google.adk.tools.tool_context import ToolContext
+from google.genai import types
+
+
+logger = logging.getLogger(__name__)
+
+
+# Local cache of created request_ids for demo purposes.
+request_ids = set()
+
+
+def create_request_form(
+ date: Optional[str] = None,
+ amount: Optional[str] = None,
+ purpose: Optional[str] = None,
+) -> dict[str, Any]:
+ """Create a request form for the employee to fill out.
+
+ Args:
+ date (str): The date of the request. Can be an empty string.
+ amount (str): The requested amount. Can be an empty string.
+ purpose (str): The purpose of the request. Can be an empty string.
+
+ Returns:
+ dict[str, Any]: A dictionary containing the request form data.
+ """
+ logger.info('Creating reimbursement request')
+ request_id = 'request_id_' + str(random.randint(1000000, 9999999))
+ request_ids.add(request_id)
+ reimbursement = {
+ 'request_id': request_id,
+ 'date': '' if not date else date,
+ 'amount': '' if not amount else amount,
+ 'purpose': (
+ ''
+ if not purpose
+ else purpose
+ ),
+ }
+ logger.info('Reimbursement request created: %s', json.dumps(reimbursement))
+
+ return reimbursement
+
+
+def return_form(
+ form_request: dict[str, Any],
+ tool_context: ToolContext,
+ instructions: Optional[str] = None,
+) -> dict[str, Any]:
+ """Returns a structured json object indicating a form to complete.
+
+ Args:
+ form_request (dict[str, Any]): The request form data.
+ tool_context (ToolContext): The context in which the tool operates.
+ instructions (str): Instructions for processing the form. Can be an empty string.
+
+ Returns:
+ dict[str, Any]: A JSON dictionary for the form response.
+ """
+ logger.info('Creating return form')
+ if isinstance(form_request, str):
+ form_request = json.loads(form_request)
+
+ form_dict = {
+ 'type': 'form',
+ 'form': {
+ 'type': 'object',
+ 'properties': {
+ 'date': {
+ 'type': 'string',
+ 'format': 'date',
+ 'description': 'Date of expense',
+ 'title': 'Date',
+ },
+ 'amount': {
+ 'type': 'string',
+ 'format': 'number',
+ 'description': 'Amount of expense',
+ 'title': 'Amount',
+ },
+ 'purpose': {
+ 'type': 'string',
+ 'description': 'Purpose of expense',
+ 'title': 'Purpose',
+ },
+ 'request_id': {
+ 'type': 'string',
+ 'description': 'Request id',
+ 'title': 'Request ID',
+ },
+ },
+ 'required': list(form_request.keys()),
+ },
+ 'form_data': form_request,
+ 'instructions': instructions,
+ }
+ logger.info('Return form created: %s', json.dumps(form_dict))
+ return json.dumps(form_dict)
+
+
+async def reimburse(request_id: str) -> dict[str, Any]:
+ """Reimburse the amount of money to the employee for a given request_id."""
+ logger.info('Starting reimbursement: %s', request_id)
+ if request_id not in request_ids:
+ return {
+ 'request_id': request_id,
+ 'status': 'Error: Invalid request_id.',
+ }
+ logger.info('Reimbursement approved: %s', request_id)
+ return {'request_id': request_id, 'status': 'approved'}
+
+
+class ReimbursementAgent:
+ """An agent that handles reimbursement requests."""
+
+ SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']
+
+ def __init__(self):
+ self._agent = self._build_agent()
+ self._user_id = 'remote_agent'
+ self._runner = Runner(
+ app_name=self._agent.name,
+ agent=self._agent,
+ artifact_service=InMemoryArtifactService(),
+ session_service=InMemorySessionService(),
+ memory_service=InMemoryMemoryService(),
+ )
+
+ async def invoke(self, query, session_id) -> AgentInvokeResult:
+ logger.info('Invoking LLM')
+ session = self._runner.session_service.get_session(
+ app_name=self._agent.name,
+ user_id=self._user_id,
+ session_id=session_id,
+ )
+ content = types.Content(
+ role='user', parts=[types.Part.from_text(text=query)]
+ )
+ if session is None:
+ self._runner.session_service.create_session(
+ app_name=self._agent.name,
+ user_id=self._user_id,
+ state={},
+ session_id=session_id,
+ )
+
+ events = []
+ async for event in self._runner.run_async(
+ user_id=self._user_id,
+ session_id=session_id,
+ new_message=content,
+ ):
+ events.append(event)
+
+ logger.info('LLM response: %s', events)
+ if not events or not events[-1].content or not events[-1].content.parts:
+ return AgentInvokeResult(
+ parts=[TextPart(text='')],
+ require_user_input=False,
+ is_task_complete=True,
+ )
+ return AgentInvokeResult(
+ parts=[
+ TextPart(
+ text='\n'.join(
+ [p.text for p in events[-1].content.parts if p.text]
+ )
+ )
+ ],
+ require_user_input=False,
+ is_task_complete=True,
+ )
+
+ def _build_agent(self) -> LlmAgent:
+ """Builds the LLM agent for the reimbursement agent."""
+ return LlmAgent(
+ model='gemini-2.0-flash-001',
+ name='reimbursement_agent',
+ description=(
+ 'This agent handles the reimbursement process for the employees'
+ ' given the amount and purpose of the reimbursement.'
+ ),
+ instruction="""
+ You are an agent who handle the reimbursement process for employees.
+
+ When you receive an reimbursement request, you should first create a new request form using create_request_form(). Only provide default values if they are provided by the user, otherwise use an empty string as the default value.
+ 1. 'Date': the date of the transaction.
+ 2. 'Amount': the dollar amount of the transaction.
+ 3. 'Business Justification/Purpose': the reason for the reimbursement.
+
+ Once you created the form, you should return the result of calling return_form with the form data from the create_request_form call.
+ Clearly let the user know which fields are required and missing.
+
+ Once you received the filled-out form back from the user, you should then check the form contains all required information:
+ 1. 'Date': the date of the transaction.
+ 2. 'Amount': the value of the amount of the reimbursement being requested.
+ 3. 'Business Justification/Purpose': the item/object/artifact of the reimbursement.
+
+ If you don't have all of the information, you should reject the request directly by calling the request_form method, providing the missing fields.
+
+
+ For valid reimbursement requests, you can then use reimburse() to reimburse the employee.
+ * In your response, you should include the request_id and the status of the reimbursement request.
+
+ """,
+ tools=[
+ create_request_form,
+ reimburse,
+ return_form,
+ ],
+ )
diff --git a/samples/python/agents/restate/middleware.py b/samples/python/agents/restate/middleware.py
new file mode 100644
index 000000000..9a70f69ab
--- /dev/null
+++ b/samples/python/agents/restate/middleware.py
@@ -0,0 +1,457 @@
+# pylint: disable=C0116
+import logging
+import uuid
+
+from collections.abc import AsyncIterable, Iterable
+from datetime import datetime
+
+import restate
+
+from common.types import (
+ A2ARequest,
+ AgentCard,
+ Artifact,
+ CancelTaskRequest,
+ CancelTaskResponse,
+ GetTaskPushNotificationRequest,
+ GetTaskPushNotificationResponse,
+ GetTaskRequest,
+ GetTaskResponse,
+ JSONRPCError,
+ JSONRPCRequest,
+ JSONRPCResponse,
+ Message,
+ Part,
+ SendTaskRequest,
+ SendTaskResponse,
+ SendTaskStreamingRequest,
+ SendTaskStreamingResponse,
+ SetTaskPushNotificationRequest,
+ SetTaskPushNotificationResponse,
+ Task,
+ TaskIdParams,
+ TaskNotFoundError,
+ TaskQueryParams,
+ TaskResubscriptionRequest,
+ TaskSendParams,
+ TaskState,
+ TaskStatus,
+ TextPart,
+)
+from pydantic import BaseModel
+from restate.serde import PydanticJsonSerde
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format='[%(asctime)s] [%(process)d] [%(levelname)s] - %(message)s',
+)
+logger = logging.getLogger(__name__)
+
+
+# MODELS
+
+
+class AgentInvokeResult(BaseModel):
+ """Result of the agent invocation."""
+
+ parts: list[Part]
+ require_user_input: bool
+ is_task_complete: bool
+
+
+# K/V stored in Restate
+TASK = 'task'
+INVOCATION_ID = 'invocation-id'
+
+
+class AgentMiddleware(Iterable[restate.Service | restate.VirtualObject]):
+ """Middleware for the agent to handle task processing and state management."""
+
+ def __init__(self, agent_card: AgentCard, agent):
+ self.agent_card = agent_card.model_copy()
+ self.agent = agent
+ self.a2a_server_name = f'{self.agent_card.name}A2AServer'
+ self.task_object_name = f'{self.agent_card.name}TaskObject'
+
+ # replace the base url with the exact url of the process_request handler.
+ restate_base_url = self.agent_card.url
+ process_request_url = (
+ f'{restate_base_url}/{self.a2a_server_name}/process_request'
+ )
+ self.agent_card.url = process_request_url
+
+ self.restate_services = []
+ _build_services(self)
+
+ def __iter__(self):
+ """Returns the services that define the agent's a2a server and task object."""
+ return iter(self.restate_services)
+
+ @property
+ def agent_card_json(self):
+ """Return the agent card"""
+ return self.agent_card.model_dump()
+
+ @property
+ def services(self) -> Iterable[restate.Service | restate.VirtualObject]:
+ """Return the services that define the agent's a2a server and task object"""
+ return self.restate_services
+
+
+def _build_services(middleware: AgentMiddleware):
+ """Creates an A2A server for reimbursement processing with customizable name and description."""
+ a2a_service = restate.Service(
+ middleware.a2a_server_name,
+ description=middleware.agent_card.description,
+ metadata={
+ 'agent': middleware.agent_card.name,
+ 'version': middleware.agent_card.version,
+ },
+ )
+ middleware.restate_services.append(a2a_service)
+
+ task_object = restate.VirtualObject(middleware.task_object_name)
+ middleware.restate_services.append(task_object)
+
+ agent = middleware.agent
+
+ class TaskObject:
+ """TaskObject is a virtual object that handles task processing and state management."""
+
+ @staticmethod
+ @task_object.handler(kind='shared')
+ async def get_invocation_id(
+ ctx: restate.ObjectSharedContext,
+ ) -> str | None:
+ task_id = ctx.key()
+ logger.info('Getting invocation id for task %s', task_id)
+ return await ctx.get(INVOCATION_ID) or None
+
+ @staticmethod
+ @task_object.handler(
+ output_serde=PydanticJsonSerde(Task), kind='shared'
+ )
+ async def get_task(
+ ctx: restate.ObjectSharedContext,
+ ) -> Task | None:
+ task_id = ctx.key()
+ logger.info('Getting task %s', task_id)
+ return await ctx.get(TASK, type_hint=Task) or None
+
+ @staticmethod
+ @task_object.handler()
+ async def cancel_task(
+ ctx: restate.ObjectContext, request: CancelTaskRequest
+ ) -> CancelTaskResponse:
+ cancelled_task = await TaskObject.update_store(
+ ctx, state=TaskState.CANCELED
+ )
+ return CancelTaskResponse(id=request.id, result=cancelled_task)
+
+ @staticmethod
+ @task_object.handler()
+ async def handle_send_task_request(
+ ctx: restate.ObjectContext, request: SendTaskRequest
+ ) -> SendTaskResponse:
+ logger.info(
+ 'Starting task execution workflow %s for task %s',
+ request.id,
+ request.params.id,
+ )
+
+ task_send_params: TaskSendParams = request.params
+ if not task_send_params.sessionId:
+ session_id = await ctx.run(
+ 'Generate session id', lambda: str(uuid.uuid4().hex)
+ )
+ task_send_params.sessionId = session_id
+
+ # Store this invocation ID so it can be cancelled by someone else
+ await TaskObject.set_invocation_id(ctx, ctx.request().id)
+
+ # Persist the request data
+ await TaskObject.upsert_task(ctx, task_send_params)
+
+ try:
+ # Forward the request to the agent
+ result = await ctx.run(
+ 'Agent invoke',
+ agent.invoke,
+ args=(
+ _get_user_query(task_send_params),
+ task_send_params.sessionId,
+ ),
+ type_hint=AgentInvokeResult,
+ )
+
+ if result.require_user_input:
+ updated_task = await TaskObject.update_store(
+ ctx,
+ state=TaskState.INPUT_REQUIRED,
+ status_message=Message(
+ role='agent', parts=result.parts
+ ),
+ )
+ else:
+ updated_task = await TaskObject.update_store(
+ ctx,
+ state=TaskState.COMPLETED,
+ artifacts=[Artifact(parts=result.parts)],
+ )
+
+ ctx.clear(INVOCATION_ID)
+ return SendTaskResponse(id=request.id, result=updated_task)
+ except restate.exceptions.TerminalError as e:
+ if e.status_code == 409 and e.message == 'cancelled':
+ logger.info('Task %s was cancelled', task_send_params.id)
+ cancelled_task = await TaskObject.update_store(
+ ctx, state=TaskState.CANCELED
+ )
+ ctx.clear(INVOCATION_ID)
+ return SendTaskResponse(
+ id=request.id, result=cancelled_task
+ )
+
+ logger.error(
+ 'Error while processing task %s: %s', task_send_params.id, e
+ )
+ failed_task = await TaskObject.update_store(
+ ctx, state=TaskState.FAILED
+ )
+ ctx.clear(INVOCATION_ID)
+ return SendTaskResponse(id=request.id, result=failed_task)
+
+ @staticmethod
+ async def update_store(
+ ctx: restate.ObjectContext,
+ state: TaskState | None,
+ status_message: Message | None = None,
+ artifacts: list[Artifact] | None = None,
+ ) -> Task:
+ task_id = ctx.key()
+ logger.info('Updating status task %s to %s', task_id, state)
+
+ task = await ctx.get(TASK, type_hint=Task)
+ if task is None:
+ logger.error('Task %s not found for updating the task', task_id)
+ raise restate.exceptions.TerminalError(
+ f'Task {task_id} not found'
+ )
+
+ new_task_status = await ctx.run(
+ 'task status',
+ lambda task_state=state: TaskStatus(
+ state=task_state,
+ timestamp=datetime.now(),
+ message=status_message,
+ ),
+ type_hint=TaskStatus,
+ )
+ prev_status = task.status
+ if prev_status.message is not None:
+ task.history.append(prev_status.message)
+ task.status = new_task_status
+
+ if artifacts is not None:
+ if task.artifacts is None:
+ task.artifacts = []
+ task.artifacts.extend(artifacts)
+
+ ctx.set(TASK, task)
+ return task
+
+ @staticmethod
+ async def set_invocation_id(
+ ctx: restate.ObjectContext, invocation_id: str
+ ):
+ task_id = ctx.key()
+ logger.info(
+ 'Adding invocation id %s for task %s', invocation_id, task_id
+ )
+ current_invocation_id = await ctx.get(INVOCATION_ID)
+ if current_invocation_id is not None:
+ raise restate.exceptions.TerminalError(
+ 'There is an ongoing invocation. How did we end up here?'
+ )
+ ctx.set(INVOCATION_ID, invocation_id)
+
+ @staticmethod
+ async def upsert_task(
+ ctx: restate.ObjectContext, task_send_params: TaskSendParams
+ ) -> Task:
+ task_id = ctx.key()
+ logger.info('Upserting task %s', task_id)
+
+ task_state = await ctx.get(TASK, type_hint=Task)
+ if task_state is None:
+ task_state = await ctx.run(
+ 'Create task',
+ lambda run_params=task_send_params: Task(
+ id=run_params.id,
+ sessionId=run_params.sessionId,
+ status=TaskStatus(
+ state=TaskState.SUBMITTED, timestamp=datetime.now()
+ ),
+ history=[run_params.message]
+ if run_params.message
+ else [],
+ ),
+ type_hint=Task,
+ )
+ else:
+ task_state.history.append(task_send_params.message)
+
+ ctx.set(TASK, task_state)
+ return task_state
+
+ class A2aService:
+ @a2a_service.handler()
+ @staticmethod
+ async def process_request(
+ ctx: restate.Context, req: JSONRPCRequest
+ ) -> JSONRPCResponse:
+ methods = {
+ GetTaskRequest: A2aService.on_get_task,
+ SendTaskRequest: A2aService.on_send_task_request,
+ CancelTaskRequest: A2aService.on_cancel_task,
+ SendTaskStreamingRequest: A2aService.on_send_task_subscribe,
+ SetTaskPushNotificationRequest: A2aService.on_set_task_push_notification,
+ GetTaskPushNotificationRequest: A2aService.on_get_task_push_notification,
+ }
+
+ try:
+ json_rpc_request = A2ARequest.validate_python(req.model_dump())
+ except Exception as e:
+ logger.error('Error validating request: %s', e)
+ return JSONRPCResponse(
+ id=req.id,
+ error=JSONRPCError(
+ code=400, message='Invalid request format'
+ ),
+ )
+ fn = methods.get(type(json_rpc_request), None)
+ if not fn:
+ return JSONRPCResponse(
+ id=req.id,
+ error=JSONRPCError(code=400, message='Method not found'),
+ )
+ try:
+ return await fn(ctx, json_rpc_request)
+ except restate.exceptions.TerminalError as e:
+ logger.error('Error processing request: %s', e)
+ return JSONRPCResponse(
+ id=req.id,
+ error=JSONRPCError(code=e.status_code, message=e.message),
+ )
+
+ @staticmethod
+ async def on_send_task_request(
+ ctx: restate.Context, request: SendTaskRequest
+ ) -> SendTaskResponse:
+ logger.info('Sending task %s', request.params.id)
+ task_send_params: TaskSendParams = request.params
+
+ return await ctx.object_call(
+ TaskObject.handle_send_task_request,
+ key=task_send_params.id,
+ arg=request,
+ idempotency_key=str(request.id),
+ )
+
+ @staticmethod
+ async def on_get_task(
+ ctx: restate.Context, request: GetTaskRequest
+ ) -> GetTaskResponse:
+ logger.info('Getting task %s', request.params.id)
+ task_query_params: TaskQueryParams = request.params
+
+ task = await ctx.object_call(
+ TaskObject.get_task, key=task_query_params.id, arg=None
+ )
+ if task is None:
+ return GetTaskResponse(id=request.id, error=TaskNotFoundError())
+
+ task_result = task.model_copy()
+ history_length = task_query_params.historyLength
+ if history_length is not None and history_length > 0:
+ task_result.history = task.history[-history_length:]
+ else:
+ # Default is no history
+ task_result.history = []
+ return GetTaskResponse(id=request.id, result=task_result)
+
+ @staticmethod
+ async def on_cancel_task(
+ ctx: restate.Context, request: CancelTaskRequest
+ ) -> CancelTaskResponse:
+ logger.info('Cancelling task %s', request.params.id)
+ task_id_params: TaskIdParams = request.params
+
+ task = await ctx.object_call(
+ TaskObject.get_task, key=task_id_params.id, arg=None
+ )
+ if task is None:
+ return CancelTaskResponse(
+ id=request.id, error=TaskNotFoundError()
+ )
+ invocation_id = await ctx.object_call(
+ TaskObject.get_invocation_id, key=task_id_params.id, arg=None
+ )
+ if invocation_id is None:
+ # Task either doesn't exist or is already completed
+ return await ctx.object_call(
+ TaskObject.cancel_task, key=task_id_params.id, arg=request
+ )
+
+ # Cancel the invocation
+ ctx.cancel_invocation(invocation_id)
+ # Wait for cancellation to complete and for the cancelled task info
+ canceled_task_info = await ctx.attach_invocation(
+ invocation_id, type_hint=SendTaskResponse
+ )
+ return CancelTaskResponse(
+ id=request.id,
+ result=canceled_task_info.result,
+ )
+
+ @staticmethod
+ async def on_set_task_push_notification(
+ ctx: restate.Context, request: SetTaskPushNotificationRequest
+ ) -> SetTaskPushNotificationResponse:
+ raise restate.exceptions.TerminalError(
+ f'Not implemented: {request.method}'
+ )
+
+ @staticmethod
+ async def on_get_task_push_notification(
+ ctx: restate.Context, request: GetTaskPushNotificationRequest
+ ) -> GetTaskPushNotificationResponse:
+ raise restate.exceptions.TerminalError(
+ f'Not implemented: {request.method}'
+ )
+
+ @staticmethod
+ async def on_send_task_subscribe(
+ ctx: restate.Context, request: SendTaskStreamingRequest
+ ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
+ raise restate.exceptions.TerminalError(
+ f'Not implemented: {request.method}'
+ )
+
+ @staticmethod
+ async def on_resubscribe_to_task(
+ ctx: restate.Context, request: TaskResubscriptionRequest
+ ) -> AsyncIterable[SendTaskResponse] | JSONRPCResponse:
+ raise restate.exceptions.TerminalError(
+ status_code=500, message=f'Not implemented: {request.method}'
+ )
+
+ return a2a_service, task_object
+
+
+def _get_user_query(task_send_params: TaskSendParams) -> str:
+ part = task_send_params.message.parts[0]
+ if not isinstance(part, TextPart):
+ raise restate.exceptions.TerminalError('Only text parts are supported')
+ return part.text
diff --git a/samples/python/agents/restate/pyproject.toml b/samples/python/agents/restate/pyproject.toml
new file mode 100644
index 000000000..e3a33051e
--- /dev/null
+++ b/samples/python/agents/restate/pyproject.toml
@@ -0,0 +1,17 @@
+[project]
+name = "a2a-samples-restate"
+version = "0.1.0"
+description = "Resilient A2A agent using the Restate SDK"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+ "a2a-samples",
+ "restate-sdk[serde]>=0.7.1",
+ "hypercorn",
+ "google-adk>=0.0.3",
+ "google-genai>=1.9.0",
+ "fastapi"
+]
+
+[tool.uv.sources]
+a2a-samples = { workspace = true }
diff --git a/samples/python/pyproject.toml b/samples/python/pyproject.toml
index b8a33a035..69c1463f7 100644
--- a/samples/python/pyproject.toml
+++ b/samples/python/pyproject.toml
@@ -30,7 +30,8 @@ members = [
"agents/llama_index_file_chat",
"agents/semantickernel",
"agents/mindsdb",
- "agents/ag2"
+ "agents/ag2",
+ "agents/restate"
]
[build-system]
diff --git a/samples/python/uv.lock b/samples/python/uv.lock
index 9e0de9876..8f4e4aba3 100644
--- a/samples/python/uv.lock
+++ b/samples/python/uv.lock
@@ -19,6 +19,7 @@ members = [
"a2a-samples-image-gen",
"a2a-samples-marvin",
"a2a-samples-mcp",
+ "a2a-samples-restate",
"a2a-semantic-kernel",
]
@@ -233,6 +234,29 @@ requires-dist = [
{ name = "google-genai", specifier = ">=1.10.0" },
]
+[[package]]
+name = "a2a-samples-restate"
+version = "0.1.0"
+source = { virtual = "agents/restate" }
+dependencies = [
+ { name = "a2a-samples" },
+ { name = "fastapi" },
+ { name = "google-adk" },
+ { name = "google-genai" },
+ { name = "hypercorn" },
+ { name = "restate-sdk", extra = ["serde"] },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "a2a-samples", editable = "." },
+ { name = "fastapi" },
+ { name = "google-adk", specifier = ">=0.0.3" },
+ { name = "google-genai", specifier = ">=1.9.0" },
+ { name = "hypercorn" },
+ { name = "restate-sdk", extras = ["serde"], specifier = ">=0.7.1" },
+]
+
[[package]]
name = "a2a-semantic-kernel"
version = "0.1.0"
@@ -1046,6 +1070,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375 },
]
+[[package]]
+name = "dacite"
+version = "1.9.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600 },
+]
+
[[package]]
name = "dataclasses-json"
version = "0.6.7"
@@ -1958,6 +1991,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 },
]
+[[package]]
+name = "hypercorn"
+version = "0.17.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "h11" },
+ { name = "h2" },
+ { name = "priority" },
+ { name = "wsproto" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/7e/3a/df6c27642e0dcb7aff688ca4be982f0fb5d89f2afd3096dc75347c16140f/hypercorn-0.17.3.tar.gz", hash = "sha256:1b37802ee3ac52d2d85270700d565787ab16cf19e1462ccfa9f089ca17574165", size = 44409 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0e/3b/dfa13a8d96aa24e40ea74a975a9906cfdc2ab2f4e3b498862a57052f04eb/hypercorn-0.17.3-py3-none-any.whl", hash = "sha256:059215dec34537f9d40a69258d323f56344805efb462959e727152b0aa504547", size = 61742 },
+]
+
[[package]]
name = "hyperframe"
version = "6.1.0"
@@ -3644,6 +3692,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a9/a8/fc509e514c708f43102542cdcbc2f42dc49f7a159f90f56d072371629731/prance-25.4.8.0-py3-none-any.whl", hash = "sha256:d3c362036d625b12aeee495621cb1555fd50b2af3632af3d825176bfb50e073b", size = 36386 },
]
+[[package]]
+name = "priority"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f5/3c/eb7c35f4dcede96fca1842dac5f4f5d15511aa4b52f3a961219e68ae9204/priority-2.0.0.tar.gz", hash = "sha256:c965d54f1b8d0d0b19479db3924c7c36cf672dbf2aec92d43fbdaf4492ba18c0", size = 24792 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5e/5f/82c8074f7e84978129347c2c6ec8b6c59f3584ff1a20bc3c940a3e061790/priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa", size = 8946 },
+]
+
[[package]]
name = "prompt-toolkit"
version = "3.0.51"
@@ -4421,6 +4478,35 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481 },
]
+[[package]]
+name = "restate-sdk"
+version = "0.7.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/60/98/840dac505dd0ec687b17e66963e8a17b0ca48c7a38cca34b3a49467de591/restate_sdk-0.7.1.tar.gz", hash = "sha256:bdbc6999838135896c3d809d02828251b6af33bac78ab55f908630b89472c203", size = 59163 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/84/56/23aa19ad0441645543e16743411b80e8e956188d1122d67cb741000761c1/restate_sdk-0.7.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7d215044e4c762aba95b375daa65de2b0cf5f9400ab046c825751f42030b9819", size = 1790030 },
+ { url = "https://files.pythonhosted.org/packages/ed/2f/0e467831b66d47f93dae51b4b1fe59cc1c6e4a1906a5fe87b82231c81ccb/restate_sdk-0.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ed5d7129ea4f78f6132933c075eeb1055b603f54ee7a25fafb4a30adc207033", size = 1714262 },
+ { url = "https://files.pythonhosted.org/packages/a0/a8/831934f62899a31872cec077f262a13fd9c5c278dd5343f8721c0b34f7fa/restate_sdk-0.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f540b7aa54cb43bb6f6393be3612f8d86685aa7983108aef51cf7a3b853a139e", size = 1973513 },
+ { url = "https://files.pythonhosted.org/packages/64/77/b4aab0c1742c996013ef72276f35b077b365d2b35b2b61c7d342594979d6/restate_sdk-0.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d5e8bdc691584801dd8d0076101b6bb9bece8027fe448f4a109a31d318a213ea", size = 1938965 },
+ { url = "https://files.pythonhosted.org/packages/43/f7/077a0b0e6f8ee4861df1c62e4cc25ac5a7c0c270b335687a477f96b3c9f9/restate_sdk-0.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f246aa753cc6190cff2aa84a0c1873e8eef983242f0111cf8b625d5415c6a6ef", size = 2126752 },
+ { url = "https://files.pythonhosted.org/packages/60/24/e2dbdccd143d859c52d723c7034c0c716ec3790408b460512de34cf77888/restate_sdk-0.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:641cb5d628edfba9ffc7b94f93203b9c89b4988f35b1e70fe6e11a45872e6edd", size = 2140057 },
+ { url = "https://files.pythonhosted.org/packages/e1/95/1eadab39a71d90b725a3f111130004ed3e1cac51af59679b3031656fb8d0/restate_sdk-0.7.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:c7b844344b441574e96240e5dcd66c2f63eb007c7177166737ae2b08f4151102", size = 1790275 },
+ { url = "https://files.pythonhosted.org/packages/46/1a/efc21da414dce524a12f723221df6da3987ea481675faac95f0020212e26/restate_sdk-0.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a39d89a76e4ed84deeae19ac7b6ba2e78400ff5640408ebd385dbdb84c49a385", size = 1714709 },
+ { url = "https://files.pythonhosted.org/packages/df/c2/c39ff95197638f1cfdb6349eb856e2487f6c7a3386019fe42d7f72331af7/restate_sdk-0.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e07b6288f0dd3f0e5cfbda0ec23fa5f46cd9f051736de3dd0488d433879759f5", size = 1973842 },
+ { url = "https://files.pythonhosted.org/packages/49/a7/85b2682f7fe2fbe541186a3222f65fc8a2a08efbfa32c9fa51aa6a45c00c/restate_sdk-0.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:188c9a9dc8b0a9bce59009ef20e5af883db6d0e9ec168e0980dea87e0564a9ea", size = 1939231 },
+ { url = "https://files.pythonhosted.org/packages/f0/70/8f5b20f24e680781d6f7a8d7ff8f5c0519f2261a774f865b24c7bf6d24b5/restate_sdk-0.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3d44b27449bc6fba8d8ab72efe01a1cde91b21b80f6b2774ba8532b30839d2d2", size = 2126751 },
+ { url = "https://files.pythonhosted.org/packages/f8/e5/79d64697f6a77980be2a2b93a3519f916ef567e66944a6b3de6bef1b0b5c/restate_sdk-0.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d67ba4a9d1cc4831312254ea9cf9dad6b54e7cff0b469d5bd20300c6739d56dd", size = 2140418 },
+ { url = "https://files.pythonhosted.org/packages/c4/d6/279d05df6c7fe2849b3427efdbcd82400396448381cdf77df6a918fbba79/restate_sdk-0.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:8c5b97ac27ed80828f4388768367f254271c39552e25ec943aaacd5f0a12064e", size = 1939642 },
+ { url = "https://files.pythonhosted.org/packages/cb/c2/ab19491e76d3f43976c432b5405fb384a3c2c66f325ec04f5a072a93b4e5/restate_sdk-0.7.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c893e762cdf8ecb994df75ebc3132803748f4ee1387795be12b4e838160268d1", size = 2127237 },
+ { url = "https://files.pythonhosted.org/packages/cf/f5/dca03eb140b00d85441f20f71799357c11ce848f38d7fc48925755763bcb/restate_sdk-0.7.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:24eac2aae43ba5b433d75999024ec59e64717fe0b237d7c891d1fe14627d82ca", size = 2139918 },
+]
+
+[package.optional-dependencies]
+serde = [
+ { name = "dacite" },
+ { name = "pydantic" },
+]
+
[[package]]
name = "rfc3339-validator"
version = "0.1.4"
@@ -5265,6 +5351,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594 },
]
+[[package]]
+name = "wsproto"
+version = "1.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "h11" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226 },
+]
+
[[package]]
name = "xxhash"
version = "3.5.0"