Skip to content

Commit a37dc4e

Browse files
authored
Refine custom endpoint API SSE streaming and configurable A2A RPC endpoints (#216)
1 parent ad0b03c commit a37dc4e

11 files changed

Lines changed: 676 additions & 409 deletions

File tree

src/agentscope_runtime/engine/app/agent_app.py

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
# -*- coding: utf-8 -*-
2-
import asyncio
32
import logging
43
import types
5-
from contextlib import asynccontextmanager
6-
from typing import Optional, Any, Callable, List
4+
from typing import Optional, Callable, List
75

86
import uvicorn
9-
from fastapi import FastAPI
107
from pydantic import BaseModel
118

129
from .base_app import BaseApp
@@ -43,6 +40,7 @@ def __init__(
4340
broker_url: Optional[str] = None,
4441
backend_url: Optional[str] = None,
4542
runner: Optional[Runner] = None,
43+
enable_embedded_worker: bool = False,
4644
**kwargs,
4745
):
4846
"""
@@ -61,6 +59,7 @@ def __init__(
6159
self.after_finish = after_finish
6260
self.broker_url = broker_url
6361
self.backend_url = backend_url
62+
self.enable_embedded_worker = enable_embedded_worker
6463

6564
self._runner = runner
6665
self.custom_endpoints = [] # Store custom endpoints
@@ -79,33 +78,16 @@ def __init__(
7978
response_protocol = ResponseAPIDefaultAdapter()
8079
self.protocol_adapters = [a2a_protocol, response_protocol]
8180

82-
@asynccontextmanager
83-
async def lifespan(app: FastAPI) -> Any:
84-
"""Manage the application lifespan."""
85-
if hasattr(self, "before_start") and self.before_start:
86-
if asyncio.iscoroutinefunction(self.before_start):
87-
await self.before_start(app, **getattr(self, "kwargs", {}))
88-
else:
89-
self.before_start(app, **getattr(self, "kwargs", {}))
90-
yield
91-
if hasattr(self, "after_finish") and self.after_finish:
92-
if asyncio.iscoroutinefunction(self.after_finish):
93-
await self.after_finish(app, **getattr(self, "kwargs", {}))
94-
else:
95-
self.after_finish(app, **getattr(self, "kwargs", {}))
96-
97-
kwargs = {
81+
self._app_kwargs = {
9882
"title": "Agent Service",
9983
"version": __version__,
10084
"description": "Production-ready Agent Service API",
101-
"lifespan": lifespan,
10285
**kwargs,
10386
}
10487

10588
super().__init__(
10689
broker_url=broker_url,
10790
backend_url=backend_url,
108-
**kwargs,
10991
)
11092

11193
# Store custom endpoints and tasks for deployment
@@ -167,7 +149,6 @@ def run(
167149
self,
168150
host="0.0.0.0",
169151
port=8090,
170-
embed_task_processor=False,
171152
**kwargs,
172153
):
173154
"""
@@ -176,7 +157,6 @@ def run(
176157
Args:
177158
host: Host to bind to
178159
port: Port to bind to
179-
embed_task_processor: Whether to embed task processor
180160
**kwargs: Additional keyword arguments
181161
"""
182162
# Build runner
@@ -186,24 +166,7 @@ def run(
186166
logger.info(
187167
"[AgentApp] Starting AgentApp with FastAPIAppFactory...",
188168
)
189-
190-
# Create FastAPI application using the factory
191-
fastapi_app = FastAPIAppFactory.create_app(
192-
runner=self._runner,
193-
endpoint_path=self.endpoint_path,
194-
request_model=self.request_model,
195-
response_type=self.response_type,
196-
stream=self.stream,
197-
before_start=self.before_start,
198-
after_finish=self.after_finish,
199-
mode=DeploymentMode.DAEMON_THREAD,
200-
protocol_adapters=self.protocol_adapters,
201-
custom_endpoints=self.custom_endpoints,
202-
broker_url=self.broker_url,
203-
backend_url=self.backend_url,
204-
enable_embedded_worker=embed_task_processor,
205-
**kwargs,
206-
)
169+
fastapi_app = self.get_fastapi_app(**kwargs)
207170

208171
logger.info(f"[AgentApp] Starting server on {host}:{port}")
209172

@@ -220,6 +183,29 @@ def run(
220183
logger.error(f"[AgentApp] Error while running: {e}")
221184
raise
222185

186+
def get_fastapi_app(self, **kwargs):
187+
"""Get the FastAPI application"""
188+
189+
self._build_runner()
190+
191+
return FastAPIAppFactory.create_app(
192+
runner=self._runner,
193+
endpoint_path=self.endpoint_path,
194+
request_model=self.request_model,
195+
response_type=self.response_type,
196+
stream=self.stream,
197+
before_start=self.before_start,
198+
after_finish=self.after_finish,
199+
mode=DeploymentMode.DAEMON_THREAD,
200+
protocol_adapters=self.protocol_adapters,
201+
custom_endpoints=self.custom_endpoints,
202+
broker_url=self.broker_url,
203+
backend_url=self.backend_url,
204+
enable_embedded_worker=self.enable_embedded_worker,
205+
app_kwargs=self._app_kwargs,
206+
**kwargs,
207+
)
208+
223209
async def deploy(self, deployer: DeployManager, **kwargs):
224210
"""Deploy the agent app with custom endpoints support"""
225211
# Pass custom endpoints and tasks to the deployer
Lines changed: 4 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,29 @@
11
# -*- coding: utf-8 -*-
22
import inspect
33
import logging
4-
import threading
54
from typing import Callable, Optional
65

7-
import uvicorn
8-
from fastapi import FastAPI, Request
9-
from fastapi.responses import StreamingResponse
6+
from fastapi import Request
107

118
from .celery_mixin import CeleryMixin
129

1310
logger = logging.getLogger(__name__)
1411

1512

16-
class BaseApp(FastAPI, CeleryMixin):
13+
class BaseApp(CeleryMixin):
1714
"""
18-
BaseApp extends FastAPI and integrates with Celery
19-
for asynchronous background task execution.
15+
BaseApp integrates Celery for asynchronous background task execution,
16+
and provides FastAPI-like routing for task endpoints.
2017
"""
2118

2219
def __init__(
2320
self,
2421
broker_url: Optional[str] = None,
2522
backend_url: Optional[str] = None,
26-
**kwargs,
2723
):
2824
# Initialize CeleryMixin
2925
CeleryMixin.__init__(self, broker_url, backend_url)
3026

31-
self.server = None
32-
33-
# Initialize FastAPI
34-
FastAPI.__init__(self, **kwargs)
35-
3627
def task(self, path: str, queue: str = "celery"):
3728
"""
3829
Register an asynchronous task endpoint.
@@ -74,108 +65,3 @@ async def get_task(task_id: str):
7465
return func
7566

7667
return decorator
77-
78-
def endpoint(self, path: str):
79-
"""
80-
Unified POST endpoint decorator.
81-
Pure FastAPI functionality, independent of Celery.
82-
Supports:
83-
- Sync functions
84-
- Async functions (coroutines)
85-
- Sync/async generator functions (streaming responses)
86-
"""
87-
88-
def decorator(func: Callable):
89-
is_async_gen = inspect.isasyncgenfunction(func)
90-
is_sync_gen = inspect.isgeneratorfunction(func)
91-
92-
if is_async_gen or is_sync_gen:
93-
# Handle streaming responses
94-
async def _stream_generator(request: Request):
95-
if is_async_gen:
96-
async for chunk in func(request):
97-
yield chunk
98-
else:
99-
for chunk in func(request):
100-
yield chunk
101-
102-
@self.post(path)
103-
async def _wrapped(request: Request):
104-
return StreamingResponse(
105-
_stream_generator(request),
106-
media_type="text/plain",
107-
)
108-
109-
else:
110-
# Handle regular responses
111-
@self.post(path)
112-
async def _wrapped(request: Request):
113-
if inspect.iscoroutinefunction(func):
114-
return await func(request)
115-
else:
116-
return func(request)
117-
118-
return func
119-
120-
return decorator
121-
122-
def run(
123-
self,
124-
host="0.0.0.0",
125-
port=8090,
126-
embed_task_processor=False,
127-
**kwargs,
128-
):
129-
"""
130-
Run FastAPI with uvicorn.
131-
"""
132-
if embed_task_processor:
133-
if self.celery_app is None:
134-
logger.warning(
135-
"[AgentApp] Celery is not configured. "
136-
"Cannot run embedded worker.",
137-
)
138-
else:
139-
logger.warning(
140-
"[AgentApp] embed_task_processor=True: Running "
141-
"task_processor in embedded thread mode. This is "
142-
"intended for development/debug purposes only. In "
143-
"production, run Celery worker in a separate process!",
144-
)
145-
146-
queues = self._registered_queues or {"celery"}
147-
queue_list = ",".join(sorted(queues))
148-
149-
def start_celery_worker():
150-
logger.info(
151-
f"[AgentApp] Embedded worker listening "
152-
f"queues: {queue_list}",
153-
)
154-
self.celery_app.worker_main(
155-
[
156-
"worker",
157-
"--loglevel=INFO",
158-
"-Q",
159-
queue_list,
160-
],
161-
)
162-
163-
threading.Thread(
164-
target=start_celery_worker,
165-
daemon=True,
166-
).start()
167-
logger.info(
168-
"[AgentApp] Embedded task processor started in background "
169-
"thread (DEV mode).",
170-
)
171-
172-
# TODO: Add CLI to main entrypoint to control run/deploy
173-
174-
config = uvicorn.Config(
175-
app=self,
176-
host=host,
177-
port=port,
178-
**kwargs,
179-
)
180-
self.server = uvicorn.Server(config)
181-
self.server.run()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# -*- coding: utf-8 -*-
2+
from .a2a import A2AFastAPIDefaultAdapter

src/agentscope_runtime/engine/deployers/adapter/a2a/a2a_protocol_adapter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import posixpath
23
from typing import Callable
34

45
from a2a.server.apps import A2AFastAPIApplication
@@ -9,12 +10,16 @@
910
from .a2a_agent_adapter import A2AExecutor
1011
from ..protocol_adapter import ProtocolAdapter
1112

13+
A2A_JSON_RPC_URL = "/a2a"
14+
1215

1316
class A2AFastAPIDefaultAdapter(ProtocolAdapter):
1417
def __init__(self, agent_name, agent_description, **kwargs):
1518
super().__init__(**kwargs)
1619
self._agent_name = agent_name
1720
self._agent_description = agent_description
21+
self._json_rpc_path = kwargs.get("json_rpc_path", A2A_JSON_RPC_URL)
22+
self._base_url = kwargs.get("base_url")
1823

1924
def add_endpoint(self, app, func: Callable, **kwargs):
2025
request_handler = DefaultRequestHandler(
@@ -32,7 +37,14 @@ def add_endpoint(self, app, func: Callable, **kwargs):
3237
http_handler=request_handler,
3338
)
3439

35-
server.add_routes_to_app(app)
40+
server.add_routes_to_app(app, rpc_url=self._json_rpc_path)
41+
42+
def _get_json_rpc_url(self) -> str:
43+
base = self._base_url or "http://127.0.0.1:8000"
44+
return posixpath.join(
45+
base.rstrip("/"),
46+
self._json_rpc_path.lstrip("/"),
47+
)
3648

3749
def get_agent_card(
3850
self,
@@ -62,6 +74,6 @@ def get_agent_card(
6274
description=agent_description,
6375
default_input_modes=["text"],
6476
default_output_modes=["text"],
65-
url="http://127.0.0.1:8090/",
77+
url=self._get_json_rpc_url(),
6678
version="1.0.0",
6779
)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# -*- coding: utf-8 -*-
2+
from .response_api_protocol_adapter import ResponseAPIDefaultAdapter

0 commit comments

Comments
 (0)