22import inspect
33import threading
44import logging
5+ from contextlib import asynccontextmanager
56from typing import Union , Callable , get_type_hints , Generator , AsyncGenerator , Iterator , AsyncIterator
67from fastapi import APIRouter , FastAPI , Response
78from fastapi .responses import JSONResponse
@@ -34,6 +35,7 @@ def __init__(
3435 prefix : str = "" , # "/api",
3536 max_upload_file_size_mb : float = None ,
3637 job_queue = None ,
38+ lifespan = None ,
3739 * args ,
3840 ** kwargs ):
3941 """
@@ -46,12 +48,17 @@ def __init__(
4648 prefix: The API route prefix
4749 max_upload_file_size_mb: Maximum file size in MB for uploads
4850 job_queue: Optional custom JobQueue implementation
51+ lifespan: Optional async context manager for custom startup/shutdown logic
4952 args: Additional arguments
5053 kwargs: Additional keyword arguments
5154 """
55+ # Extract user-provided lifespan (explicit param or kwarg) before parent init
56+ user_lifespan = lifespan or kwargs .pop ('lifespan' , None )
57+
5258 # Initialize parent classes
5359 api_router_params = inspect .signature (APIRouter .__init__ ).parameters
5460 api_router_kwargs = {k : kwargs .get (k ) for k in api_router_params if k in kwargs }
61+ api_router_kwargs .pop ('lifespan' , None ) # handled via composed lifespan below
5562
5663 APIRouter .__init__ (self , ** api_router_kwargs )
5764 _BaseBackend .__init__ (self , title = title , summary = summary , * args , ** kwargs )
@@ -61,27 +68,35 @@ def __init__(
6168
6269 self .status = SERVER_HEALTH .INITIALIZING
6370
71+ # Registry for functions that workers can execute. Keys are function names.
72+ self ._job_func_registry : dict = {}
73+ # Stop event and thread handle for in-process worker (dev mode)
74+ self ._worker_stop_event = threading .Event ()
75+ self ._worker_thread : threading .Thread | None = None
76+ self ._logger = logging .getLogger (__name__ )
77+
78+ # Build a composed lifespan that merges internal worker hooks with the user-provided lifespan
79+ combined_lifespan = self ._build_lifespan (user_lifespan )
80+
6481 # Create or use provided FastAPI app
6582 if app is None :
6683 app = FastAPI (
6784 title = self .title ,
6885 summary = self .summary ,
69- contact = {"name" : "SocAIty" , "url" : "https://www.socaity.ai" }
86+ contact = {"name" : "SocAIty" , "url" : "https://www.socaity.ai" },
87+ lifespan = combined_lifespan ,
7088 )
89+ else :
90+ # Existing app: replace its lifespan with our composed version
91+ app .router .lifespan_context = combined_lifespan
7192
7293 self .app : FastAPI = app
7394 self .prefix = prefix
7495 self .add_standard_routes ()
7596
76- # Registry for functions that workers can execute. Keys are function names.
77- self ._job_func_registry : dict = {}
78- # Stop event and thread handle for in-process worker (dev mode)
79- self ._worker_stop_event = threading .Event ()
80- self ._worker_thread : threading .Thread | None = None
81- self ._logger = logging .getLogger (__name__ )
8297 self ._endpoint_configurator = FastApiEndpointConfigurator (self )
8398
84- # excpetion handling
99+ # Exception handling
85100 _FastAPIExceptionHandler .__init__ (self )
86101 if not getattr (self .app .state , "_socaity_exception_handler_added" , False ):
87102 self .app .add_exception_handler (Exception , self .global_exception_handler )
@@ -91,50 +106,69 @@ def __init__(
91106 self ._orig_openapi_func = self .app .openapi
92107 self .app .openapi = self .custom_openapi
93108
94- # Start in-process worker on FastAPI startup (dev convenience).
95- # Only start if a job_queue with `start_worker` exists.
96- if not getattr (self .app .state , "_socaity_worker_hooks_added" , False ):
97- def _startup ():
98- try :
99- if self .job_queue and hasattr (self .job_queue , "start_worker" ):
100- # Start worker in a daemon thread so it doesn't block uvicorn
101- def _run ():
102- try :
103- self .job_queue .start_worker (
104- func_registry = self ._job_func_registry ,
105- worker_name = "api-worker" ,
106- stop_event = self ._worker_stop_event ,
107- )
108- except Exception :
109- self ._logger .exception ("Worker thread exited with exception" )
110-
111- t = threading .Thread (target = _run , daemon = True )
112- t .start ()
113- self ._worker_thread = t
114- except Exception :
115- self ._logger .exception ("Failed to start in-process worker on startup" )
116-
117- def _shutdown ():
118- try :
119- # Signal local worker to stop
109+ # ------------------------------------------------------------------
110+ # Lifespan & worker lifecycle
111+ # ------------------------------------------------------------------
112+
113+ def _build_lifespan (self , user_lifespan = None ):
114+ """
115+ Build a composed lifespan context manager that runs:
116+ 1. Internal worker startup
117+ 2. User-provided lifespan (if any)
118+ 3. Internal worker shutdown on exit
119+ """
120+ router_self = self # capture for closure
121+
122+ @asynccontextmanager
123+ async def _combined_lifespan (app ):
124+ router_self ._start_background_worker ()
125+ try :
126+ if user_lifespan :
127+ async with user_lifespan (app ):
128+ yield
129+ else :
130+ yield
131+ finally :
132+ router_self ._stop_background_worker ()
133+
134+ return _combined_lifespan
135+
136+ def _start_background_worker (self ):
137+ """Start the in-process job queue worker in a daemon thread (dev convenience)."""
138+ try :
139+ if self .job_queue and hasattr (self .job_queue , "start_worker" ):
140+ def _run ():
120141 try :
121- self ._worker_stop_event .set ()
142+ self .job_queue .start_worker (
143+ func_registry = self ._job_func_registry ,
144+ worker_name = "api-worker" ,
145+ stop_event = self ._worker_stop_event ,
146+ )
122147 except Exception :
123- pass
124-
125- # Call job_queue.shutdown if available
126- if self .job_queue and hasattr (self .job_queue , "shutdown" ):
127- try :
128- self .job_queue .shutdown ()
129- except Exception :
130- self ._logger .exception ("Error shutting down job queue" )
131- except Exception :
132- self ._logger .exception ("Error during worker shutdown handler" )
133-
134- # Register handlers
135- self .app .add_event_handler ("startup" , _startup )
136- self .app .add_event_handler ("shutdown" , _shutdown )
137- self .app .state ._socaity_worker_hooks_added = True
148+ self ._logger .exception ("Worker thread exited with exception" )
149+
150+ thread = threading .Thread (target = _run , daemon = True )
151+ thread .start ()
152+ self ._worker_thread = thread
153+ except Exception :
154+ self ._logger .exception ("Failed to start in-process worker on startup" )
155+
156+ def _stop_background_worker (self ):
157+ """Signal the background worker to stop and shut down the job queue."""
158+ try :
159+ self ._worker_stop_event .set ()
160+ except Exception :
161+ pass
162+
163+ if self .job_queue and hasattr (self .job_queue , "shutdown" ):
164+ try :
165+ self .job_queue .shutdown ()
166+ except Exception :
167+ self ._logger .exception ("Error shutting down job queue" )
168+
169+ # ------------------------------------------------------------------
170+ # Standard routes
171+ # ------------------------------------------------------------------
138172
139173 def add_standard_routes (self ):
140174 """Add standard API routes for status and health checks."""
0 commit comments