77import platform
88import shlex
99import subprocess
10+ import time
1011import types
12+ import uuid
1113from contextlib import asynccontextmanager , AsyncExitStack
1214from typing import Any , Callable , Dict , List , Optional , Type
1315
@@ -173,6 +175,7 @@ def __init__(
173175 self .enable_stream_task = enable_stream_task
174176 self .stream_task_queue = stream_task_queue
175177 self .stream_task_timeout = stream_task_timeout
178+ self ._stream_query_celery_task : Optional [Callable ] = None
176179
177180 self ._query_handler : Optional [Callable ] = None
178181 self ._init_handler : Optional [Callable ] = None
@@ -249,6 +252,7 @@ async def _internal_framework_lifespan(self, app: FastAPI):
249252 """
250253 # pylint: disable=too-many-branches
251254 self ._build_runner ()
255+ cleanup_task = None
252256 try :
253257 # aexit any possible running instances before set up
254258 # runner
@@ -272,9 +276,22 @@ async def _internal_framework_lifespan(self, app: FastAPI):
272276 if self .enable_embedded_worker and self .celery_app :
273277 self .start_embedded_celery_worker ()
274278
279+ if self .enable_stream_task :
280+ cleanup_task = asyncio .create_task (
281+ self ._task_cleanup_worker (),
282+ )
283+ logger .info ("Started task cleanup worker" )
284+
275285 yield
276286
277287 finally :
288+ if cleanup_task :
289+ cleanup_task .cancel ()
290+ try :
291+ await cleanup_task
292+ except asyncio .CancelledError :
293+ pass
294+
278295 if self .after_finish :
279296 try :
280297 if asyncio .iscoroutinefunction (self .after_finish ):
@@ -406,12 +423,92 @@ async def root():
406423
407424 self ._add_process_control_endpoints ()
408425
409- def _add_stream_query_task_endpoint (self ):
426+ async def _cleanup_expired_tasks (self ):
427+ """
428+ Remove completed/failed tasks older than TTL.
429+
430+ Returns:
431+ Number of tasks cleaned up
432+ """
433+ now = time .time ()
434+ ttl_seconds = 3600 # 1 hour
435+
436+ expired = []
437+ for task_id , info in self .active_tasks .items ():
438+ status = info .get ("status" )
439+
440+ if status in ["completed" , "failed" ]:
441+ finished_at = info .get ("completed_at" ) or info .get (
442+ "failed_at" ,
443+ )
444+ if finished_at and (now - finished_at ) > ttl_seconds :
445+ expired .append (task_id )
446+
447+ for task_id in expired :
448+ del self .active_tasks [task_id ]
449+ if hasattr (self , "task_locks" ) and task_id in self .task_locks :
450+ del self .task_locks [task_id ]
451+
452+ if expired :
453+ logger .info (
454+ f"Cleaned up { len (expired )} expired tasks. "
455+ f"Active tasks: { len (self .active_tasks )} " ,
456+ )
457+
458+ return len (expired )
459+
async def _task_cleanup_worker(self):
    """Periodically sweep expired tasks until this worker is cancelled."""
    sweep_interval = 300  # seconds between sweeps (5 minutes)
    while True:
        try:
            # Sleep first so a freshly started app does not sweep
            # immediately on boot.
            await asyncio.sleep(sweep_interval)
            await self._cleanup_expired_tasks()
        except asyncio.CancelledError:
            # Normal shutdown path — the lifespan handler cancels us.
            logger.info("Task cleanup worker stopped")
            return
        except Exception as e:
            # Best-effort: a failed sweep must not kill the worker.
            logger.error(f"Task cleanup failed: {e}")
472+ def _create_stream_query_wrapper (self ):
473+ """
474+ Create a wrapper function for stream_query that collects only
475+ the final response.
476+
477+ This wrapper is used by Celery to execute stream_query as a
478+ background task.
479+ """
480+
481+ async def stream_query_wrapper (request : dict ):
482+ """Wrapper that collects only final response from stream_query"""
483+ final_response = None
484+
485+ async for event in self ._runner .stream_query (request ):
486+ if hasattr (event , "model_dump" ):
487+ final_response = event .model_dump ()
488+ elif hasattr (event , "dict" ):
489+ final_response = event .dict ()
490+ else :
491+ final_response = {"data" : str (event )}
492+
493+ return final_response
494+
495+ return stream_query_wrapper
496+
497+ def _add_stream_query_task_endpoint (self ) -> None :
410498 """
411499 Add background task endpoints for stream_query.
412- Creates POST /process/task and GET /process/task/{task_id}.
413500
501+ Creates POST /process/task and GET /process/task/{task_id}.
414502 Design: Only stores the final response, not intermediate events.
503+ Supports both Celery and in-memory modes.
504+
505+ Args:
506+ self (AgentApp): The application instance on which to register
507+ the task endpoints.
508+
509+ Returns:
510+ None: This method registers routes on the application and does
511+ not return a value.
415512 """
416513 if not self .enable_stream_task :
417514 logger .debug ("Stream task disabled, skipping task endpoint setup" )
@@ -447,22 +544,27 @@ def _add_stream_query_task_endpoint(self):
447544 @UnifiedRoutingMixin .internal_route
448545 async def submit_stream_query_task (request : dict ):
449546 """Submit stream_query as background task"""
450- import uuid
451-
452547 task_id = str (uuid .uuid4 ())
453548
454549 if self .celery_app :
550+ if self ._stream_query_celery_task is None :
551+ wrapper_func = self ._create_stream_query_wrapper ()
552+ self ._stream_query_celery_task = self .register_celery_task (
553+ wrapper_func ,
554+ self .stream_task_queue ,
555+ )
556+
557+ result = self ._stream_query_celery_task .delay (request )
558+
455559 return {
456- "error " : (
457- "Celery mode not yet implemented for stream tasks"
458- ) ,
459- "suggestion " : (
460- "Use in-memory mode or contribute implementation "
560+ "task_id " : result . id ,
561+ "status" : "submitted" ,
562+ "queue" : self . stream_task_queue ,
563+ "message " : (
564+ "Stream query task submitted to Celery successfully "
461565 ),
462566 }
463567 else :
464- import time
465-
466568 self .active_tasks [task_id ] = {
467569 "task_id" : task_id ,
468570 "status" : "submitted" ,
0 commit comments