
Commit 7156118

Commit message: fix
1 parent: c9edaff

6 files changed
Lines changed: 273 additions & 86 deletions


cookbook/en/agent_app.md

Lines changed: 6 additions & 4 deletions
@@ -435,10 +435,12 @@ Status query response (completed):
 
 **Important Notes**
 
-1. **In-memory mode**: By default uses in-memory storage; task state is lost on process restart
-2. **Persistence**: For production, configure Celery with Redis for task persistence
-3. **Storage**: Only stores final response; intermediate streaming events are not saved
-4. **Timeout**: Set reasonable timeout based on agent complexity
+1. **Dual mode support**:
+   - **In-memory mode** (default): Task state is lost on restart; suitable for development/testing
+   - **Celery mode**: Configure `broker_url` and `backend_url` to enable; tasks persisted; suitable for production
+2. **Storage**: Only stores final response; intermediate streaming events are not saved
+3. **Timeout**: Set reasonable timeout based on agent complexity
+4. **Worker requirement**: Celery mode requires running workers (use `enable_embedded_worker=True`)
 
 ------
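As a point of reference for these notes, a minimal sketch of the two modes. It assumes `enable_stream_task`, `broker_url`, `backend_url`, and `enable_embedded_worker` are accepted by the `AgentApp` constructor (all four names appear in this commit, but the exact signature is not shown here); the Redis URLs are illustrative:

    from agentscope_runtime.engine.app.agent_app import AgentApp

    # In-memory mode (default): task state lives in the process and is
    # lost on restart; fine for development and testing.
    app = AgentApp(enable_stream_task=True)

    # Celery mode: configure a broker and result backend so tasks are
    # persisted; the embedded worker runs inside the app process.
    app = AgentApp(
        enable_stream_task=True,
        broker_url="redis://localhost:6379/0",   # illustrative URL
        backend_url="redis://localhost:6379/1",  # illustrative URL
        enable_embedded_worker=True,
    )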

cookbook/zh/agent_app.md

Lines changed: 6 additions & 4 deletions
@@ -440,10 +440,12 @@ while True:
 
 **Notes**
 
-1. **In-memory mode**: Uses in-memory mode by default; task state is lost after a process restart
-2. **Persistence**: For production environments, configuring Celery (requires Redis) for task persistence is recommended
-3. **Result storage**: Only the final response is stored; intermediate streaming events are not saved
-4. **Timeout**: Set a reasonable timeout based on agent complexity
+1. **Dual mode support**:
+   - **In-memory mode** (default): Task state is lost after a process restart; suitable for development and testing
+   - **Celery mode**: Enabled by configuring `broker_url` and `backend_url`; tasks are persisted; suitable for production
+2. **Result storage**: Only the final response is stored; intermediate streaming events are not saved
+3. **Timeout**: Set a reasonable timeout based on agent complexity
+4. **Worker requirement**: Celery mode requires starting workers (`enable_embedded_worker=True` can be used)
 
 ------

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ dev = [
     "sphinx-autoapi>=3.6.0",
     "pytest-mock>=3.15.1",
     "sphinxcontrib-mermaid>=1.2.3",
+    "aiohttp>=3.9.0",
 ]
 
 ext = [
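`aiohttp` joins the dev dependencies here, presumably to exercise the new task endpoints from async tests. A hedged sketch of driving POST /process/task and GET /process/task/{task_id} with it; the host, port, and request payload shape are assumptions:

    import asyncio

    import aiohttp

    async def poll_task():
        async with aiohttp.ClientSession() as session:
            # Submit a stream_query background task (endpoint added in this commit).
            async with session.post(
                "http://localhost:8000/process/task",  # host/port assumed
                json={"input": "hello"},               # payload shape assumed
            ) as resp:
                task_id = (await resp.json())["task_id"]

            # Poll until the task reaches a terminal state.
            while True:
                async with session.get(
                    f"http://localhost:8000/process/task/{task_id}",
                ) as resp:
                    status = await resp.json()
                if status.get("status") in ("completed", "failed"):
                    return status
                await asyncio.sleep(1)

    print(asyncio.run(poll_task()))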

src/agentscope_runtime/engine/app/agent_app.py

Lines changed: 113 additions & 11 deletions
@@ -7,7 +7,9 @@
 import platform
 import shlex
 import subprocess
+import time
 import types
+import uuid
 from contextlib import asynccontextmanager, AsyncExitStack
 from typing import Any, Callable, Dict, List, Optional, Type
 
@@ -173,6 +175,7 @@ def __init__(
         self.enable_stream_task = enable_stream_task
         self.stream_task_queue = stream_task_queue
         self.stream_task_timeout = stream_task_timeout
+        self._stream_query_celery_task: Optional[Callable] = None
 
         self._query_handler: Optional[Callable] = None
         self._init_handler: Optional[Callable] = None
@@ -249,6 +252,7 @@ async def _internal_framework_lifespan(self, app: FastAPI):
         """
         # pylint: disable=too-many-branches
         self._build_runner()
+        cleanup_task = None
         try:
             # aexit any possible running instances before set up
             # runner
@@ -272,9 +276,22 @@
             if self.enable_embedded_worker and self.celery_app:
                 self.start_embedded_celery_worker()
 
+            if self.enable_stream_task:
+                cleanup_task = asyncio.create_task(
+                    self._task_cleanup_worker(),
+                )
+                logger.info("Started task cleanup worker")
+
             yield
 
         finally:
+            if cleanup_task:
+                cleanup_task.cancel()
+                try:
+                    await cleanup_task
+                except asyncio.CancelledError:
+                    pass
+
             if self.after_finish:
                 try:
                     if asyncio.iscoroutinefunction(self.after_finish):
@@ -406,12 +423,92 @@ async def root():
 
         self._add_process_control_endpoints()
 
-    def _add_stream_query_task_endpoint(self):
+    async def _cleanup_expired_tasks(self):
+        """
+        Remove completed/failed tasks older than the TTL.
+
+        Returns:
+            Number of tasks cleaned up
+        """
+        now = time.time()
+        ttl_seconds = 3600  # 1 hour
+
+        expired = []
+        for task_id, info in self.active_tasks.items():
+            status = info.get("status")
+
+            if status in ["completed", "failed"]:
+                finished_at = info.get("completed_at") or info.get(
+                    "failed_at",
+                )
+                if finished_at and (now - finished_at) > ttl_seconds:
+                    expired.append(task_id)
+
+        for task_id in expired:
+            del self.active_tasks[task_id]
+            if hasattr(self, "task_locks") and task_id in self.task_locks:
+                del self.task_locks[task_id]
+
+        if expired:
+            logger.info(
+                f"Cleaned up {len(expired)} expired tasks. "
+                f"Active tasks: {len(self.active_tasks)}",
+            )
+
+        return len(expired)
+
+    async def _task_cleanup_worker(self):
+        """Background worker to clean up expired tasks periodically."""
+        while True:
+            try:
+                await asyncio.sleep(300)  # Run every 5 minutes
+                await self._cleanup_expired_tasks()
+            except asyncio.CancelledError:
+                logger.info("Task cleanup worker stopped")
+                break
+            except Exception as e:
+                logger.error(f"Task cleanup failed: {e}")
+
+    def _create_stream_query_wrapper(self):
+        """
+        Create a wrapper function for stream_query that collects only
+        the final response.
+
+        This wrapper is used by Celery to execute stream_query as a
+        background task.
+        """
+
+        async def stream_query_wrapper(request: dict):
+            """Wrapper that collects only the final response from stream_query"""
+            final_response = None
+
+            async for event in self._runner.stream_query(request):
+                if hasattr(event, "model_dump"):
+                    final_response = event.model_dump()
+                elif hasattr(event, "dict"):
+                    final_response = event.dict()
+                else:
+                    final_response = {"data": str(event)}
+
+            return final_response
+
+        return stream_query_wrapper
+
+    def _add_stream_query_task_endpoint(self) -> None:
         """
         Add background task endpoints for stream_query.
-        Creates POST /process/task and GET /process/task/{task_id}.
 
+        Creates POST /process/task and GET /process/task/{task_id}.
         Design: Only stores the final response, not intermediate events.
+        Supports both Celery and in-memory modes.
+
+        Args:
+            self (AgentApp): The application instance on which to register
+                the task endpoints.
+
+        Returns:
+            None: This method registers routes on the application and does
+                not return a value.
         """
         if not self.enable_stream_task:
             logger.debug("Stream task disabled, skipping task endpoint setup")
@@ -447,22 +544,27 @@ def _add_stream_query_task_endpoint(self):
         @UnifiedRoutingMixin.internal_route
        async def submit_stream_query_task(request: dict):
             """Submit stream_query as a background task"""
-            import uuid
-
             task_id = str(uuid.uuid4())
 
             if self.celery_app:
+                if self._stream_query_celery_task is None:
+                    wrapper_func = self._create_stream_query_wrapper()
+                    self._stream_query_celery_task = self.register_celery_task(
+                        wrapper_func,
+                        self.stream_task_queue,
+                    )
+
+                result = self._stream_query_celery_task.delay(request)
+
                 return {
-                    "error": (
-                        "Celery mode not yet implemented for stream tasks"
-                    ),
-                    "suggestion": (
-                        "Use in-memory mode or contribute implementation"
+                    "task_id": result.id,
+                    "status": "submitted",
+                    "queue": self.stream_task_queue,
+                    "message": (
+                        "Stream query task submitted to Celery successfully"
                     ),
                 }
             else:
-                import time
-
                 self.active_tasks[task_id] = {
                     "task_id": task_id,
                     "status": "submitted",
src/agentscope_runtime/engine/deployers/utils/service_utils/routing/task_engine_mixin.py

Lines changed: 89 additions & 34 deletions
@@ -19,6 +19,8 @@ def init_task_engine(
         self.celery_app = None
         self.active_tasks: Dict[str, Dict[str, Any]] = {}
         self._registered_queues: set[str] = set()
+        self.task_locks: Dict[str, asyncio.Lock] = {}
+        self._tasks_lock: Optional[asyncio.Lock] = None
 
         if broker_url and backend_url:
             try:
@@ -156,6 +158,24 @@ def _run_celery_task_processor(
 
         self.celery_app.worker_main(cmd)
 
+    async def _get_task_lock(self, task_id: str) -> asyncio.Lock:
+        """
+        Get or create a lock for a specific task.
+
+        Args:
+            task_id: Task identifier
+
+        Returns:
+            asyncio.Lock for the specified task
+        """
+        if self._tasks_lock is None:
+            self._tasks_lock = asyncio.Lock()
+
+        async with self._tasks_lock:
+            if task_id not in self.task_locks:
+                self.task_locks[task_id] = asyncio.Lock()
+            return self.task_locks[task_id]
+
     async def execute_background_task(
         self,
         task_id: str,
@@ -242,54 +262,89 @@ async def execute_stream_query_task(
 
         Returns:
             Final response event as dict
+
+        Raises:
+            TimeoutError: If the task exceeds the specified timeout
+            RuntimeError: If the stream yields no events
         """
         # pylint:disable=unused-argument
+        task_lock = await self._get_task_lock(task_id)
+
         try:
-            self.active_tasks[task_id].update(
-                {
-                    "status": "running",
-                    "started_at": time.time(),
-                },
-            )
+            async with task_lock:
+                self.active_tasks[task_id].update(
+                    {
+                        "status": "running",
+                        "started_at": time.time(),
+                    },
+                )
 
             final_response = None
             start_time = time.time()
+            event_count = 0
+
+            async def stream_with_collection():
+                nonlocal final_response, event_count
+                async for event in stream_func(request):
+                    event_count += 1
+
+                    if hasattr(event, "model_dump"):
+                        final_response = event.model_dump()
+                    elif hasattr(event, "dict"):
+                        final_response = event.dict()
+                    else:
+                        final_response = {"data": str(event)}
+
+            if timeout is not None:
+                await asyncio.wait_for(
+                    stream_with_collection(),
+                    timeout=timeout,
+                )
+            else:
+                await stream_with_collection()
 
-            async for event in stream_func(request):
-                if timeout and (time.time() - start_time) > timeout:
-                    raise TimeoutError(
-                        f"Task {task_id} exceeded timeout of {timeout}s",
-                    )
-
-                if hasattr(event, "model_dump"):
-                    final_response = event.model_dump()
-                elif hasattr(event, "dict"):
-                    final_response = event.dict()
-                else:
-                    final_response = {"data": str(event)}
+            if event_count == 0 or final_response is None:
+                raise RuntimeError(
+                    f"Stream function yielded no events for task {task_id}",
+                )
 
             elapsed = time.time() - start_time
 
-            self.active_tasks[task_id].update(
-                {
-                    "status": "completed",
-                    "result": final_response,
-                    "completed_at": time.time(),
-                    "elapsed_time": elapsed,
-                },
-            )
+            async with task_lock:
+                self.active_tasks[task_id].update(
+                    {
+                        "status": "completed",
+                        "result": final_response,
+                        "completed_at": time.time(),
+                        "elapsed_time": elapsed,
+                    },
+                )
 
             return final_response
 
+        except asyncio.TimeoutError:
+            async with task_lock:
+                self.active_tasks[task_id].update(
+                    {
+                        "status": "failed",
+                        "error": f"Task exceeded timeout of {timeout}s",
+                        "error_type": "TimeoutError",
+                        "failed_at": time.time(),
+                    },
+                )
+            raise
+
         except Exception as e:
-            self.active_tasks[task_id].update(
-                {
-                    "status": "failed",
-                    "error": str(e),
-                    "error_type": type(e).__name__,
-                    "failed_at": time.time(),
-                },
-            )
+            async with task_lock:
+                self.active_tasks[task_id].update(
+                    {
+                        "status": "failed",
+                        "error": str(e),
+                        "error_type": type(e).__name__,
+                        "failed_at": time.time(),
+                    },
+                )
+            raise
 
     def get_task_status(self, task_id):
         # pylint:disable=too-many-return-statements